@@ -133,27 +133,23 @@ def inference_on_waveform(
133133 }
134134
135135 # Accumulate segment-level labels
136- segment_labels = []
137- for pred in inference ["y_preds" ]:
138- segment_labels .append (index_to_class [pred ])
139-
140- # Majority vote to get final prediction
141- unique_preds = np .unique (inference ["y_preds" ])
142- # Aggregated results sorted by predicted class and mean posterior for that class, in descending order
143- aggregated_results = sorted (
144- [
145- (
146- pred ,
147- inference ["y_preds" ][inference ["y_preds" ] == pred ].sum (),
148- inference ["posteriors" ][inference ["y_preds" ] == pred ].mean (),
149- )
150- for pred in unique_preds
151- ],
152- key = lambda x : (x [1 ], x [2 ]),
153- reverse = True ,
154- )
155- # First item is the winner with highest mean posterior / handles ties by mean posterior
156- final_winner_index , counts , final_posterior = aggregated_results [0 ]
136+ segment_labels = [index_to_class [pred ] for pred in inference ["y_preds" ]]
137+
138+ # Majority vote to get final prediction (tie-break by mean posterior)
139+ y_preds = inference ["y_preds" ]
140+ posteriors = inference ["posteriors" ]
141+
142+ num_classes = len (class_mapping )
143+ counts = np .bincount (y_preds , minlength = num_classes )
144+ sum_posteriors = np .bincount (y_preds , weights = posteriors , minlength = num_classes )
145+
146+ valid = counts > 0
147+ mean_posteriors = np .zeros_like (sum_posteriors , dtype = float )
148+ mean_posteriors [valid ] = sum_posteriors [valid ] / counts [valid ]
149+
150+ candidates = np .where (valid )[0 ]
151+ final_winner_index = max (candidates , key = lambda cls : (counts [cls ], mean_posteriors [cls ]))
152+ final_posterior = mean_posteriors [final_winner_index ]
157153
158154 return AudioPrediction (
159155 final_label = index_to_class [final_winner_index ],
0 commit comments