1010import zarr
1111from scipy .spatial .distance import dice
1212from scipy .ndimage import distance_transform_edt
13- from fastremap import remap , unique
13+ from fastremap import remap , unique , renumber
1414import cc3d
1515from cc3d .types import StatisticsDict , StatisticsSlicesDict
1616
3939MAX_SEMANTIC_THREADS = int (os .getenv ("MAX_SEMANTIC_THREADS" , 25 ))
4040PER_INSTANCE_THREADS = int (os .getenv ("PER_INSTANCE_THREADS" , 25 ))
4141MAX_DISTANCE_CAP_EPS = float (os .getenv ("MAX_DISTANCE_CAP_EPS" , "1e-4" ))
42- INSTANCE_RATIO_CUTOFF = float (os .getenv ("INSTANCE_RATIO_CUTOFF" , 10 ))
42+ FINAL_INSTANCE_RATIO_CUTOFF = float (os .getenv ("FINAL_INSTANCE_RATIO_CUTOFF" , 10 ))
43+ INITIAL_INSTANCE_RATIO_CUTOFF = float (os .getenv ("INITIAL_INSTANCE_RATIO_CUTOFF" , 50 ))
44+ INSTANCE_RATIO_FACTOR = float (os .getenv ("INSTANCE_RATIO_FACTOR" , 5.0 ))
4345MAX_OVERLAP_EDGES = int (os .getenv ("MAX_OVERLAP_EDGES" , "5000000" ))
4446
4547
def ratio_cutoff(
    nG: int,
    R_base: float = FINAL_INSTANCE_RATIO_CUTOFF,
    R_extra: float = INITIAL_INSTANCE_RATIO_CUTOFF,
    k: float = INSTANCE_RATIO_FACTOR,
) -> float:
    """Exponentially decaying tolerance for the pred/GT instance-count ratio.

    The allowed ratio starts near ``R_base + R_extra`` when very few
    ground-truth instances exist and decays toward ``R_base`` as ``nG``
    grows, with decay constant ``k``.

    Args:
        nG: Number of ground-truth instances. ``nG == 0`` is handled
            upstream (the ratio itself is undefined there); for
            completeness this returns the maximum tolerance in that case.
        R_base: Asymptotic cutoff for large ``nG``.
        R_extra: Extra slack granted when ``nG`` is small.
        k: Decay constant (in number of instances) for the extra slack.

    Returns:
        The maximum allowed ``nP / nG`` ratio, as a plain ``float``.
    """
    decay = np.exp(-nG / k)
    return float(R_base + R_extra * decay)
57+
4658def normalize_distance (distance : float , voxel_size ) -> float :
4759 """
4860 Normalize a distance value to [0, 1] using the maximum distance represented by a voxel
@@ -86,7 +98,7 @@ def match_instances(gt: np.ndarray, pred: np.ndarray) -> dict | None:
8698 logging .info ("No GT or Pred instances; returning only background match." )
8799 return {0 : 0 }
88100
89- if (nP / nG ) > INSTANCE_RATIO_CUTOFF :
101+ if (nP / nG ) > ratio_cutoff ( nG ) :
90102 logging .warning (
91103 f"WARNING: Skipping { nP } instances in submission, { nG } in ground truth, "
92104 f"because there are too many instances in the submission."
@@ -412,7 +424,7 @@ def compute_hausdorff_distance_roi(
412424 if a_n == 0 and b_n == 0 :
413425 return 0.0
414426 elif a_n == 0 or b_n == 0 :
415- return np . inf
427+ return max_distance
416428
417429 vs = np .asarray (voxel_size , dtype = np .float64 )
418430 if vs .size != ndim :
@@ -446,7 +458,7 @@ def score_instance(
446458 truth_label ,
447459 voxel_size ,
448460 hausdorff_distance_max = None ,
449- ) -> dict [str , float ]:
461+ ) -> dict [str , float | str ]:
450462 """
451463 Score a single instance label volume against the ground truth instance label volume.
452464
@@ -465,13 +477,15 @@ def score_instance(
465477 logging .info ("Scoring instance segmentation..." )
466478 if hausdorff_distance_max is None :
467479 hausdorff_distance_max = compute_default_max_distance (voxel_size )
468- logging .info (
480+ logging .debug (
469481 f"Using default maximum Hausdorff distance of { hausdorff_distance_max :.2f} for voxel size { voxel_size } ."
470482 )
471483
472484 # Relabel the predicted instance labels to be consistent with the ground truth instance labels
473485 logging .info ("Relabeling predicted instance labels..." )
474- pred_label , n_pred = cc3d .connected_components (pred_label , return_N = True )
486+ # pred_label, n_pred = cc3d.connected_components(pred_label, return_N=True)
487+ pred_label , remapping = renumber (pred_label , in_place = True )
488+ n_pred = len (remapping ) - 1 # exclude background
475489
476490 # Match instances between ground truth and prediction
477491 mapping = match_instances (truth_label , pred_label )
@@ -480,9 +494,13 @@ def score_instance(
480494 # Too many instances in submission, skip scoring
481495 return {
482496 "accuracy" : 0 ,
483- "hausdorff_distance" : np .inf ,
484- "normalized_hausdorff_distance" : 0 ,
497+ "binary_accuracy" : ((truth_label > 0 ) == (pred_label > 0 )).mean (),
498+ "hausdorff_distance" : hausdorff_distance_max ,
499+ "normalized_hausdorff_distance" : normalize_distance (
500+ hausdorff_distance_max , voxel_size
501+ ),
485502 "combined_score" : 0 ,
503+ "status" : "skipped_too_many_instances" ,
486504 }
487505 elif len (mapping ) == 1 and 0 in mapping :
488506 # Only background present in both ground truth and prediction
@@ -504,23 +522,24 @@ def score_instance(
504522 hausdorff_distances = np .concatenate (
505523 [
506524 hausdorff_distances ,
507- np .full (len (unmatched_pred ), np .inf , dtype = np .float32 ),
525+ np .full (
526+ len (unmatched_pred ), hausdorff_distance_max , dtype = np .float32
527+ ),
508528 ]
509529 )
510530 else :
511531 # No predictions to match (no GT XOR no Pred instances)
512- hausdorff_distances = []
532+ hausdorff_distances = [hausdorff_distance_max ]
533+
534+ if len (hausdorff_distances ) == 0 :
535+ hausdorff_distances = [hausdorff_distance_max ]
513536
514537 # Compute the scores
515538 logging .info ("Computing accuracy score..." )
516539 accuracy = float ((truth_label == pred_label ).mean ())
517- hausdorff_dist = (
518- np .mean (hausdorff_distances ) if len (hausdorff_distances ) > 0 else np .inf
519- )
520- normalized_hausdorff_dist = (
521- np .mean ([normalize_distance (hd , voxel_size ) for hd in hausdorff_distances ])
522- if len (hausdorff_distances ) > 0
523- else 0.0
540+ hausdorff_dist = np .mean (hausdorff_distances )
541+ normalized_hausdorff_dist = np .mean (
542+ [normalize_distance (hd , voxel_size ) for hd in hausdorff_distances ]
524543 )
525544 combined_score = (accuracy * normalized_hausdorff_dist ) ** 0.5 # geometric mean
526545 logging .info (f"Accuracy: { accuracy :.4f} " )
@@ -532,6 +551,7 @@ def score_instance(
532551 "hausdorff_distance" : hausdorff_dist ,
533552 "normalized_hausdorff_distance" : normalized_hausdorff_dist ,
534553 "combined_score" : combined_score ,
554+ "status" : "scored" ,
535555 } # type: ignore
536556
537557
@@ -566,6 +586,7 @@ def score_semantic(pred_label, truth_label) -> dict[str, float]:
566586 scores = {
567587 "iou" : iou_score ,
568588 "dice_score" : dice_score if not np .isnan (dice_score ) else 1 ,
589+ "status" : "scored" ,
569590 }
570591
571592 logging .info (f"IoU: { scores ['iou' ]:.4f} " )
@@ -655,36 +676,29 @@ def score_label(
def empty_label_score(
    label, crop_name, instance_classes=INSTANCE_CLASSES, truth_path=TRUTH_PATH
):
    """Return an all-zero score dict for a label missing from the submission.

    The ground-truth array is opened only to read its shape and voxel size,
    so the zeroed scores can still be voxel-weighted when results are
    combined later.

    Args:
        label: Name of the label (class) being scored.
        crop_name: Name of the crop/volume containing the label.
        instance_classes: Labels scored as instance segmentations; all
            other labels are scored as semantic segmentations.
        truth_path: Root path of the ground-truth zarr store.

    Returns:
        dict: Zeroed instance scores (accuracy/Hausdorff/combined) or
        zeroed semantic scores (iou/dice), plus ``num_voxels``,
        ``voxel_size``, ``is_missing`` and ``status`` metadata.
    """
    truth_path = UPath(truth_path)
    ds = zarr.open((truth_path / crop_name / label).path, mode="r")
    voxel_size = ds.attrs["voxel_size"]
    if label in instance_classes:
        # NOTE: the redundant second `truth_path = UPath(truth_path)` that
        # previously lived here (a leftover from before the open() was
        # hoisted above) has been removed; it had no effect.
        return {
            "accuracy": 0,
            # Worst case: a missing prediction is penalized with the
            # maximum Hausdorff distance for this voxel size.
            "hausdorff_distance": compute_default_max_distance(voxel_size),
            "normalized_hausdorff_distance": 0,
            "combined_score": 0,
            "num_voxels": int(np.prod(ds.shape)),
            "voxel_size": voxel_size,
            "is_missing": True,
            "status": "missing",
        }
    else:
        return {
            "iou": 0,
            "dice_score": 0,
            "num_voxels": int(np.prod(ds.shape)),
            "voxel_size": voxel_size,
            "is_missing": True,
            "status": "missing",
        }
690704
@@ -755,55 +769,32 @@ def get_evaluation_args(
755769
756770
757771def missing_volume_score (
758- truth_volume_path , instance_classes = INSTANCE_CLASSES
772+ truth_path , volume , instance_classes = INSTANCE_CLASSES
759773) -> list [tuple ]:
760774 """
761775 Score a missing volume as 0's, congruent with the score_volume function.
762776
763777 Args:
764- truth_volume_path (str): The path to the ground truth volume.
778+ truth_path (str): The path to the ground truth volume.
779+ volume (str): The name of the volume.
780+ instance_classes (list): A list of instance classes.
765781
766782 Returns:
767783 dict: A dictionary of scores for the volume.
768784
769785 Example usage:
770786 scores = missing_volume_score('truth.zarr/test_volume')
771787 """
772- logging .info (f"Scoring missing volume { truth_volume_path } ..." )
773- truth_volume_path = UPath (truth_volume_path )
788+ logging .info (f"Scoring missing volume { volume } ..." )
789+ truth_path = UPath (truth_path )
790+ truth_volume_path = truth_path / volume
774791
775792 # Find labels to score
776793 truth_labels = [a for a in ensure_zgroup (truth_volume_path ).array_keys ()]
777794
778795 # Score each label
779796 scores = {
780- label : (
781- {
782- "accuracy" : 0.0 ,
783- "hausdorff_distance" : 0.0 ,
784- "normalized_hausdorff_distance" : 0.0 ,
785- "combined_score" : 0.0 ,
786- "num_voxels" : int (
787- np .prod (zarr .open ((truth_volume_path / label ).path , mode = "r" ).shape )
788- ),
789- "voxel_size" : zarr .open (
790- (truth_volume_path / label ).path , mode = "r"
791- ).attrs ["voxel_size" ],
792- "is_missing" : True ,
793- }
794- if label in instance_classes
795- else {
796- "iou" : 0.0 ,
797- "dice_score" : 0.0 ,
798- "num_voxels" : int (
799- np .prod (zarr .open ((truth_volume_path / label ).path , mode = "r" ).shape )
800- ),
801- "voxel_size" : zarr .open (
802- (truth_volume_path / label ).path , mode = "r"
803- ).attrs ["voxel_size" ],
804- "is_missing" : True ,
805- }
806- )
797+ label : empty_label_score (label , volume , instance_classes , truth_path )
807798 for label in truth_labels
808799 }
809800
@@ -883,16 +874,30 @@ def combine_scores(
883874 logging .info ("Computing overall scores..." )
884875 overall_instance_scores = []
885876 overall_semantic_scores = []
877+ instance_total_voxels = sum (
878+ total_voxels [label ] for label in label_scores if label in instance_classes
879+ )
880+ semantic_total_voxels = sum (
881+ total_voxels [label ] for label in label_scores if label not in instance_classes
882+ )
886883 for label in label_scores :
887884 if label in instance_classes :
888- overall_instance_scores += [label_scores [label ]["combined_score" ]]
885+ overall_instance_scores += [
886+ label_scores [label ]["combined_score" ] * total_voxels [label ]
887+ ]
889888 else :
890- overall_semantic_scores += [label_scores [label ]["iou" ]]
889+ overall_semantic_scores += [
890+ label_scores [label ]["iou" ] * total_voxels [label ]
891+ ]
891892 scores ["overall_instance_score" ] = (
892- np .nanmean (overall_instance_scores ) if overall_instance_scores else 0
893+ np .nansum (overall_instance_scores ) / instance_total_voxels
894+ if overall_instance_scores
895+ else 0
893896 )
894897 scores ["overall_semantic_score" ] = (
895- np .nanmean (overall_semantic_scores ) if overall_semantic_scores else 0
898+ np .nansum (overall_semantic_scores ) / semantic_total_voxels
899+ if overall_semantic_scores
900+ else 0
896901 )
897902 scores ["overall_score" ] = (
898903 scores ["overall_instance_score" ] * scores ["overall_semantic_score" ]
@@ -1058,7 +1063,7 @@ def score_submission(
10581063
10591064 scores = {
10601065 volume : missing_volume_score (
1061- truth_path / volume , instance_classes = instance_classes
1066+ truth_path , volume , instance_classes = instance_classes
10621067 )
10631068 for volume in missing_volumes
10641069 }
@@ -1148,6 +1153,8 @@ def sanitize_scores(scores):
11481153 for key , value in label_scores .items ():
11491154 if value is None :
11501155 continue
1156+ if isinstance (value , str ):
1157+ continue
11511158 if not np .isscalar (value ) and len (value ) == 1 :
11521159 value = value [0 ]
11531160 if np .isscalar (value ):
0 commit comments