Skip to content

Commit

Permalink
Merge pull request #13 from earthspecies/Lou1sM-bidirectional
Browse files Browse the repository at this point in the history
Lou1s m bidirectional
  • Loading branch information
benjaminsshoffman authored Jul 11, 2024
2 parents 4a1dd60 + aab07d3 commit bc17309
Show file tree
Hide file tree
Showing 11 changed files with 644 additions and 557 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ For example, say you annotate your audio with the labels Red-eyed Vireo `REVI`,
Here are some additional options that can be applied during training:

- Flag `--stereo` accepts stereo audio. Order of channels matters; used for e.g. speaker diarization.
- Flag `--multichannel` accepts audio with >1 channel. Order of channels does not matter.
- Flag `--bidirectional` predicts the ends of events in addition to the beginnings, and matches starts with ends based on IoU. May improve box regression.
- Flag `--segmentation-based` switches to a frame-based approach. If used, we recommend putting `--rho=1`.
- Flag `--mixup` applies mixup augmentation.

Expand Down
145 changes: 75 additions & 70 deletions voxaboxen/data/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import os
import math
import numpy as np
import pandas as pd
import librosa
import warnings

from numpy.random import default_rng
from intervaltree import IntervalTree
Expand All @@ -16,16 +14,16 @@
def normalize_sig_np(sig, eps=1e-8):
    """Peak-normalize a signal so its maximum absolute value is ~1.

    `eps` guards against division by zero for an all-zero signal.
    """
    peak = np.abs(sig).max()
    return sig / (peak + eps)

def crop_and_pad(wav, sr, dur_sec):
    """Force a waveform to exactly sr * dur_sec samples along its last dim.

    Crops anything beyond the target length, then zero-pads the end if the
    waveform is too short. Used after resampling to guarantee clip size.
    """
    n_target = int(sr * dur_sec)
    out = wav[..., :n_target]

    n_missing = n_target - out.size(-1)
    if n_missing > 0:
        # F.pad's (left, right) pair applies to the last dimension
        out = F.pad(out, (0, n_missing))

    return out

class DetectionDataset(Dataset):
Expand All @@ -47,7 +45,7 @@ def __init__(self, info_df, train, args, random_seed_shift = 0):
if self.amp_aug:
self.amp_aug_low_r = args.amp_aug_low_r
self.amp_aug_high_r = args.amp_aug_high_r
assert (self.amp_aug_low_r >= 0) #and (self.amp_aug_high_r <= 1) and
assert (self.amp_aug_low_r >= 0) #and (self.amp_aug_high_r <= 1) and
assert (self.amp_aug_low_r <= self.amp_aug_high_r)

self.scale_factor = args.scale_factor
Expand All @@ -61,14 +59,14 @@ def __init__(self, info_df, train, args, random_seed_shift = 0):
self.mono = False
else:
self.mono = True

if self.train:
self.omit_empty_clip_prob = args.omit_empty_clip_prob
self.clip_start_offset = self.rng.integers(0, np.floor(self.clip_hop*self.sr)) / self.sr
else:
self.omit_empty_clip_prob = 0
self.clip_start_offset = 0

self.args=args
# make metadata
self.make_metadata()
Expand All @@ -89,15 +87,15 @@ def process_selection_table(self, selection_table_fp):
start = row['Begin Time (s)']
end = row['End Time (s)']
label = row['Annotation']

if end<=start:
continue

if label in self.label_mapping:
label = self.label_mapping[label]
else:
continue

if label == self.unknown_label:
label_idx = -1
else:
Expand All @@ -113,7 +111,7 @@ def make_metadata(self):
for ii, row in self.info_df.iterrows():
fn = row['fn']
audio_fp = row['audio_fp']

duration = librosa.get_duration(path=audio_fp)
selection_table_fp = row['selection_table_fp']

Expand Down Expand Up @@ -144,10 +142,10 @@ def get_pos_intervals(self, fn, start, end):
intervals = [(max(iv.begin, start)-start, min(iv.end, end)-start, iv.data) for iv in intervals]

return intervals

def get_class_proportions(self):
counts = np.zeros((self.n_classes,))

for k in self.selection_table_dict:
st = self.selection_table_dict[k]
for interval in st:
Expand All @@ -156,98 +154,104 @@ def get_class_proportions(self):
continue
else:
counts[annot] += 1

total_count = np.sum(counts)
proportions = counts / total_count

return proportions


def get_annotation(self, pos_intervals, audio):
    """Rasterize event intervals into per-frame training targets.

    The scraped diff had merged the old (forward-only) and new
    (bidirectional) versions of this method: the old 3-tuple `return`
    made the reverse-direction code unreachable, `dur_samples` was
    computed twice, and stale `regression_anno`/`class_anno` arrays
    shadowed the `*_annos` arrays actually returned. This is the
    reconstructed post-merge bidirectional version.

    Args:
        pos_intervals: iterable of (start_sec, end_sec, class_idx) tuples,
            with class_idx == -1 meaning "unknown" label.
        audio: waveform tensor; only its last-dim length is used.

    Returns:
        (anchor_annos, regression_annos, class_annos,
         rev_anchor_annos, rev_regression_annos, rev_class_annos)
        with shapes [time_steps], [time_steps], [time_steps, n_classes]
        (times two — the rev_* targets anchor on event ENDS).
    """
    raw_seq_len = audio.shape[-1]
    seq_len = int(math.ceil(raw_seq_len / self.scale_factor_raw_to_prediction))
    anno_sr = int(self.sr // self.scale_factor_raw_to_prediction)

    regression_annos = np.zeros((seq_len,))
    class_annos = np.zeros((seq_len, self.n_classes))
    anchor_annos = [np.zeros(seq_len,)]
    rev_regression_annos = np.zeros((seq_len,))
    rev_class_annos = np.zeros((seq_len, self.n_classes))
    rev_anchor_annos = [np.zeros(seq_len,)]

    for iv in pos_intervals:
        start, end, class_idx = iv
        dur = end - start
        dur_samples = int(np.ceil(dur * anno_sr))

        # clamp both anchor indices into the valid frame range
        start_idx = int(math.floor(start * anno_sr))
        start_idx = max(min(start_idx, seq_len - 1), 0)
        end_idx = int(math.ceil(end * anno_sr))
        end_idx = max(min(end_idx, seq_len - 1), 0)

        # forward targets: anchor + duration regressed at the event start
        anchor_annos.append(get_anchor_anno(start_idx, dur_samples, seq_len))
        regression_annos[start_idx] = dur

        # reverse targets: anchor + duration regressed at the event end
        rev_anchor_annos.append(get_anchor_anno(end_idx, dur_samples, seq_len))
        rev_regression_annos[end_idx] = dur

        if hasattr(self.args, "segmentation_based") and self.args.segmentation_based:
            # segmentation mode: mark every frame the event covers
            if class_idx != -1:
                class_annos[start_idx:start_idx + dur_samples, class_idx] = 1.
        else:
            if class_idx != -1:
                class_annos[start_idx, class_idx] = 1.
                rev_class_annos[end_idx, class_idx] = 1.
            else:
                # unknown label: enforce uncertainty across all classes
                class_annos[start_idx, :] = 1. / self.n_classes
                rev_class_annos[end_idx, :] = 1. / self.n_classes

    # overlapping gaussians are merged by taking the per-frame maximum
    anchor_annos = np.amax(np.stack(anchor_annos), axis=0)
    rev_anchor_annos = np.amax(np.stack(rev_anchor_annos), axis=0)

    return anchor_annos, regression_annos, class_annos, rev_anchor_annos, rev_regression_annos, rev_class_annos

def __getitem__(self, index):
    """Load, normalize and annotate the index-th clip.

    The scraped diff had duplicated the librosa.load and resample calls
    and kept both the old 4-tuple and new 7-tuple returns; this is the
    deduplicated post-merge (bidirectional) version.

    Returns:
        (audio, anchor, regression, class,
         rev_anchor, rev_regression, rev_class) — audio as a float tensor,
        the six annotation arrays converted to torch tensors.
    """
    fn, audio_fp, start, end = self.metadata[index]
    audio, file_sr = librosa.load(audio_fp, sr=None, offset=start, duration=self.clip_duration, mono=self.mono)
    audio = torch.from_numpy(audio)

    # remove DC offset before any augmentation/resampling
    audio = audio - torch.mean(audio, -1, keepdim=True)
    if self.amp_aug and self.train:
        audio = self.augment_amplitude(audio)
    if file_sr != self.sr:
        audio = torchaudio.functional.resample(audio, file_sr, self.sr)
    audio = crop_and_pad(audio, self.sr, self.clip_duration)

    pos_intervals = self.get_pos_intervals(fn, start, end)
    anchor_anno, regression_anno, class_anno, rev_anchor_anno, rev_regression_anno, rev_class_anno = self.get_annotation(pos_intervals, audio)

    return (audio,
            torch.from_numpy(anchor_anno), torch.from_numpy(regression_anno), torch.from_numpy(class_anno),
            torch.from_numpy(rev_anchor_anno), torch.from_numpy(rev_regression_anno), torch.from_numpy(rev_class_anno))

def __len__(self):
    """Number of (fn, audio_fp, start, end) clips in the dataset."""
    n_clips = len(self.metadata)
    return n_clips


def get_train_dataloader(args, random_seed_shift = 0):
    """Build the shuffled training DataLoader from args.train_info_fp.

    The scraped diff had duplicated the `batch_size=` and `pin_memory=`
    keyword lines (old + new versions merged) and carried dead
    commented-out mixup batch-size code; both are removed here.

    Args:
        args: namespace with train_info_fp, batch_size, num_workers, etc.
        random_seed_shift: offset passed through to the dataset's RNG seed.

    Returns:
        torch DataLoader over a DetectionDataset (train=True), shuffled,
        pinned, and dropping the last partial batch.
    """
    train_info_fp = args.train_info_fp
    train_info_df = pd.read_csv(train_info_fp)

    train_dataset = DetectionDataset(train_info_df, True, args, random_seed_shift = random_seed_shift)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

    return train_dataloader


class SingleClipDataset(Dataset):
def __init__(self, audio_fp, clip_hop, args, annot_fp = None):
# waveform (samples,)
Expand All @@ -265,26 +269,26 @@ def __init__(self, audio_fp, clip_hop, args, annot_fp = None):
self.mono = False
else:
self.mono = True

def __len__(self):
    """Total number of hop-spaced clips covering the audio file."""
    return self.num_clips

def __getitem__(self, idx):
    """Map int idx to the idx-th clip's waveform tensor.

    The scraped diff had duplicated the resample call (old + new versions
    of a whitespace-only change merged together); deduplicated here so
    the waveform is resampled exactly once.
    """
    start = idx * self.clip_hop

    audio, file_sr = librosa.load(self.audio_fp, sr=None, offset=start, duration=self.clip_duration, mono=self.mono)
    audio = torch.from_numpy(audio)

    # remove DC offset, then resample only when the file rate differs
    audio = audio - torch.mean(audio, -1, keepdim=True)
    if file_sr != self.sr:
        audio = torchaudio.functional.resample(audio, file_sr, self.sr)

    audio = crop_and_pad(audio, self.sr, self.clip_duration)

    return audio

def get_single_clip_data(audio_fp, clip_hop, args, annot_fp = None):
return DataLoader(
SingleClipDataset(audio_fp, clip_hop, args, annot_fp = annot_fp),
Expand All @@ -296,33 +300,33 @@ def get_single_clip_data(audio_fp, clip_hop, args, annot_fp = None):
)

def get_val_dataloader(args):
    """Build one single-clip validation dataloader per file in args.val_info_fp.

    The scraped diff had duplicated the `val_info_fp` assignment
    (old + new versions of a whitespace-only change); deduplicated here.

    Returns:
        dict mapping each file's fn to its DataLoader (clip hop = half
        the clip duration, i.e. 50% overlap).
    """
    val_info_fp = args.val_info_fp
    val_info_df = pd.read_csv(val_info_fp)

    val_dataloaders = {}

    for i in range(len(val_info_df)):
        fn = val_info_df.iloc[i]['fn']
        audio_fp = val_info_df.iloc[i]['audio_fp']
        annot_fp = val_info_df.iloc[i]['selection_table_fp']
        val_dataloaders[fn] = get_single_clip_data(audio_fp, args.clip_duration/2, args, annot_fp = annot_fp)

    return val_dataloaders

def get_test_dataloader(args):
    """Build one single-clip evaluation dataloader per file in args.test_info_fp.

    Returns:
        dict mapping each file's fn to its DataLoader (clip hop = half
        the clip duration, i.e. 50% overlap).
    """
    info_df = pd.read_csv(args.test_info_fp)

    loaders = {}
    for _, row in info_df.iterrows():
        loaders[row['fn']] = get_single_clip_data(
            row['audio_fp'], args.clip_duration / 2, args, annot_fp=row['selection_table_fp'])

    return loaders

def get_anchor_anno(start_idx, dur_samples, seq_len):
# start times plus gaussian blur
# std setting follows CornerNet, where adaptive standard deviation is set to 1/3 image radius
Expand All @@ -331,4 +335,5 @@ def get_anchor_anno(start_idx, dur_samples, seq_len):
x = x / (2 * std**2)
x = np.exp(-x)
return x



Loading

0 comments on commit bc17309

Please sign in to comment.