Skip to content

Commit 13f595d

Browse files
committed
Optimize inference by using bincount instead of masking
1 parent 2d212b2 commit 13f595d

4 files changed

Lines changed: 20 additions & 24 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "deepaudio-x"
3-
version = "0.3.5"
3+
version = "0.3.6"
44
description = "DeepAudio-X: Self-supervised audio toolkit for audio classification and beyond."
55
authors = [
66
{ name = "Christos Nikou", email = "chrisnick92@gmail.com" },

src/deepaudiox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
- training, evaluation, and inference workflows
1212
"""
1313

14-
__version__ = "0.3.5"
14+
__version__ = "0.3.6"
1515

1616
# Top-level API exports
1717
from deepaudiox.datasets.audio_classification_dataset import ( # noqa: F401

src/deepaudiox/modules/baseclasses.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -133,27 +133,23 @@ def inference_on_waveform(
133133
}
134134

135135
# Accumulate segment-level labels
136-
segment_labels = []
137-
for pred in inference["y_preds"]:
138-
segment_labels.append(index_to_class[pred])
139-
140-
# Majority vote to get final prediction
141-
unique_preds = np.unique(inference["y_preds"])
142-
# Aggregated results sorted by predicted class and mean posterior for that class, in descending order
143-
aggregated_results = sorted(
144-
[
145-
(
146-
pred,
147-
inference["y_preds"][inference["y_preds"] == pred].sum(),
148-
inference["posteriors"][inference["y_preds"] == pred].mean(),
149-
)
150-
for pred in unique_preds
151-
],
152-
key=lambda x: (x[1], x[2]),
153-
reverse=True,
154-
)
155-
# First item is the winner with highest mean posterior / handles ties by mean posterior
156-
final_winner_index, counts, final_posterior = aggregated_results[0]
136+
segment_labels = [index_to_class[pred] for pred in inference["y_preds"]]
137+
138+
# Majority vote to get final prediction (tie-break by mean posterior)
139+
y_preds = inference["y_preds"]
140+
posteriors = inference["posteriors"]
141+
142+
num_classes = len(class_mapping)
143+
counts = np.bincount(y_preds, minlength=num_classes)
144+
sum_posteriors = np.bincount(y_preds, weights=posteriors, minlength=num_classes)
145+
146+
valid = counts > 0
147+
mean_posteriors = np.zeros_like(sum_posteriors, dtype=float)
148+
mean_posteriors[valid] = sum_posteriors[valid] / counts[valid]
149+
150+
candidates = np.where(valid)[0]
151+
final_winner_index = max(candidates, key=lambda cls: (counts[cls], mean_posteriors[cls]))
152+
final_posterior = mean_posteriors[final_winner_index]
157153

158154
return AudioPrediction(
159155
final_label=index_to_class[final_winner_index],

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)