diff --git a/README.md b/README.md index b3fa7dc..7c69d9f 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,15 @@ After running `python main.py project-setup`, a `project_config.yaml` file will For example, say you annotate your audio with the labels Red-eyed Vireo `REVI`, Philadelphia Vireo `PHVI`, and Unsure `REVI/PHVI`. To reflect your uncertainty about `REVI/PHVI`, your `label_set` would include `REVI` and `PHVI`, and your `label_mapping` would include the pairs `REVI: REVI`, `PHVI: PHVI`, and `REVI/PHVI: Unknown`. Alternatively, you could group both types of Vireo together by making your `label_set` only include `Vireo`, and your `label_mapping` include `REVI: Vireo`, `PHVI: Vireo`, `REVI/PHVI: Vireo`. +## Other features + +Here are some additional options that can be applied during training: + +- Flag `--stereo` accepts stereo audio. The order of channels matters; useful, e.g., for speaker diarization. +- Flag `--multichannel` accepts audio with more than one channel. The order of channels does not matter. +- Flag `--segmentation-based` switches to a frame-based approach. If used, we recommend also setting `--rho=1` (e.g., pass `--segmentation-based --rho=1` alongside your other training arguments). +- Flag `--mixup` applies mixup augmentation. + ## The name Voxaboxen is designed to put a *box* around each vocalization (*vox*). It also rhymes with [Roxaboxen](https://www.thriftbooks.com/w/roxaboxen_alice-mclerran/331707/). diff --git a/voxaboxen/data/data.py b/voxaboxen/data/data.py index 620c614..5f29e66 100644 --- a/voxaboxen/data/data.py +++ b/voxaboxen/data/data.py @@ -57,6 +57,8 @@ def __init__(self, info_df, train, args, random_seed_shift = 0): self.train=train if hasattr(args, 'stereo') and args.stereo: self.mono = False + elif hasattr(args, 'multichannel') and args.multichannel: + self.mono = False else: self.mono = True @@ -66,7 +68,8 @@ def __init__(self, info_df, train, args, random_seed_shift = 0): else: self.omit_empty_clip_prob = 0 self.clip_start_offset = 0 - + + self.args=args # make metadata self.make_metadata() @@ -177,16 +180,23 @@ def get_annotation(self, pos_intervals, audio): start_idx = int(math.floor(start*anno_sr)) start_idx = max(min(start_idx, seq_len-1), 0) - dur_samples = np.ceil(dur * anno_sr) + dur_samples = int(np.ceil(dur * anno_sr)) anchor_anno = get_anchor_anno(start_idx, dur_samples, seq_len) anchor_annos.append(anchor_anno) regression_anno[start_idx] = dur - - if class_idx != -1: - class_anno[start_idx, class_idx] = 1. + + if hasattr(self.args,"segmentation_based") and self.args.segmentation_based: + if class_idx == -1: + pass + else: + class_anno[start_idx:start_idx+dur_samples,class_idx]=1. + + else: - class_anno[start_idx, :] = 1./self.n_classes # if unknown, enforce uncertainty + if class_idx != -1: + class_anno[start_idx, class_idx] = 1. 
+ else: + class_anno[start_idx, :] = 1./self.n_classes # if unknown, enforce uncertainty anchor_annos = np.stack(anchor_annos) anchor_annos = np.amax(anchor_annos, axis = 0) @@ -222,14 +232,14 @@ def get_train_dataloader(args, random_seed_shift = 0): train_dataset = DetectionDataset(train_info_df, True, args, random_seed_shift = random_seed_shift) - if args.mixup: - effective_batch_size = args.batch_size*2 # double batch size because half will be discarded before being passed to model - else: - effective_batch_size = args.batch_size + # if args.mixup: + # effective_batch_size = args.batch_size*2 # double batch size because half will be discarded before being passed to model + # else: + # effective_batch_size = args.batch_size train_dataloader = DataLoader(train_dataset, - batch_size=effective_batch_size, + batch_size=args.batch_size, #effective_batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, @@ -251,6 +261,8 @@ def __init__(self, audio_fp, clip_hop, args, annot_fp = None): self.sr = args.sr if hasattr(args, 'stereo') and args.stereo: self.mono = False + elif hasattr(args, 'multichannel') and args.multichannel: + self.mono = False else: self.mono = True diff --git a/voxaboxen/evaluation/evaluation.py b/voxaboxen/evaluation/evaluation.py index 4896ca1..f623652 100644 --- a/voxaboxen/evaluation/evaluation.py +++ b/voxaboxen/evaluation/evaluation.py @@ -126,7 +126,10 @@ def generate_predictions(model, single_clip_dataloader, args, verbose = True): X, _, _, _ = rms_and_mixup(X, None, None, None, False, args) detection, regression, classification = model(X) - classification = torch.nn.functional.softmax(classification, dim=-1) + if hasattr(args, "segmentation_based") and args.segmentation_based: + classification=torch.nn.functional.sigmoid(classification) + else: + classification=torch.nn.functional.softmax(classification, dim=-1) all_detections.append(detection) all_regressions.append(regression) @@ -200,51 +203,109 @@ def generate_features(model, single_clip_dataloader, args, verbose = True): return all_features.detach().cpu().numpy() +def fill_holes(m, max_hole): + stops = m[:-1] * ~m[1:] + stops = np.where(stops)[0] + + for stop in stops: + look_forward = m[stop+1:stop+1+max_hole] + if np.any(look_forward): + next_start = np.amin(np.where(look_forward)[0]) + stop + 1 + m[stop : next_start] = True + + return m + +def delete_short(m, min_pos): + starts = m[1:] * ~m[:-1] + + starts = np.where(starts)[0] + 1 + + clips = [] + + for start in starts: + look_forward = m[start:] + ends = np.where(~look_forward)[0] + if len(ends)>0: + clips.append((start, start+np.amin(ends))) + else: + clips.append((start, len(m)-1)) + + m = np.zeros_like(m).astype(bool) + for clip in clips: + if clip[1] - clip[0] >= min_pos: + m[clip[0]:clip[1]] = True + + return m + def export_to_selection_table(detections, regressions, classifications, fn, args, verbose=True, target_dir=None, detection_threshold = 0.5, classification_threshold = 0): if target_dir is None: target_dir = args.experiment_output_dir - -# Debugging -# -# target_fp = os.path.join(target_dir, f"detections_{fn}.npy") -# np.save(target_fp, detections) - -# target_fp = os.path.join(target_dir, f"regressions_{fn}.npy") -# np.save(target_fp, regressions) - -# target_fp = os.path.join(target_dir, f"classifications_{fn}.npy") -# np.save(target_fp, classifications) - - ## peaks - detection_peaks, properties = find_peaks(detections, height = detection_threshold, distance=args.peak_distance) - detection_probs = properties['peak_heights'] - 
- ## regressions and classifications - durations = [] - class_idxs = [] - class_probs = [] - - for i in detection_peaks: - dur = regressions[i] - durations.append(dur) - - c = np.argmax(classifications[i,:]) - p = classifications[i,c] - - if p < classification_threshold: - c = -1 - - class_idxs.append(c) - class_probs.append(p) - - durations = np.array(durations) - class_idxs = np.array(class_idxs) - class_probs = np.array(class_probs) - pred_sr = args.sr // (args.scale_factor * args.prediction_scale_factor) - - bboxes, detection_probs, class_idxs, class_probs = pred2bbox(detection_peaks, detection_probs, durations, class_idxs, class_probs, pred_sr) + if hasattr(args, "segmentation_based") and args.segmentation_based: + pred_sr = args.sr // (args.scale_factor * args.prediction_scale_factor) + bboxes = [] + detection_probs = [] + class_idxs = [] + class_probs = [] + for c in range(np.shape(classifications)[1]): + classifications_sub=classifications[:,c] + classifications_sub_binary=(classifications_sub>=detection_threshold) + classifications_sub_binary=fill_holes(classifications_sub_binary,int(args.fill_holes_dur_sec*pred_sr)) + classifications_sub_binary=delete_short(classifications_sub_binary,int(args.delete_short_dur_sec*pred_sr)) + + starts = classifications_sub_binary[1:] * ~classifications_sub_binary[:-1] + starts = np.where(starts)[0] + 1 + + for start in starts: + look_forward = classifications_sub_binary[start:] + ends = np.where(~look_forward)[0] + if len(ends)>0: + end = start+np.amin(ends) + else: + end = len(classifications_sub_binary)-1 + + bbox = [start/pred_sr,end/pred_sr] + bboxes.append(bbox) + detection_probs.append(classifications_sub[start:end].mean()) + class_idxs.append(c) + class_probs.append(classifications_sub[start:end].mean()) + + bboxes=np.array(bboxes) + detection_probs=np.array(detection_probs) + class_idxs=np.array(class_idxs) + class_probs=np.array(class_probs) + + else: + ## peaks + detection_peaks, properties = find_peaks(detections, height = detection_threshold, distance=args.peak_distance) + detection_probs = properties['peak_heights'] + + ## regressions and classifications + durations = [] + class_idxs = [] + class_probs = [] + + for i in detection_peaks: + dur = regressions[i] + durations.append(dur) + + c = np.argmax(classifications[i,:]) + p = classifications[i,c] + + if p < classification_threshold: + c = -1 + + class_idxs.append(c) + class_probs.append(p) + + durations = np.array(durations) + class_idxs = np.array(class_idxs) + class_probs = np.array(class_probs) + + pred_sr = args.sr // (args.scale_factor * args.prediction_scale_factor) + + bboxes, detection_probs, class_idxs, class_probs = pred2bbox(detection_peaks, detection_probs, durations, class_idxs, class_probs, pred_sr) if args.nms == "soft_nms": bboxes, detection_probs, class_idxs, class_probs = soft_nms(bboxes, detection_probs, class_idxs, class_probs, sigma = args.soft_nms_sigma, thresh = args.detection_threshold) @@ -261,8 +322,9 @@ def export_to_selection_table(detections, regressions, classifications, fn, args return target_fp -def get_metrics(predictions_fp, annotations_fp, args, iou, class_threshold): +def get_metrics(predictions_fp, annotations_fp, args, iou, class_threshold, duration): c = Clip(label_set=args.label_set, unknown_label=args.unknown_label) + c.duration = duration c.load_predictions(predictions_fp) c.threshold_class_predictions(class_threshold) @@ -292,12 +354,12 @@ def get_confusion_matrix(predictions_fp, annotations_fp, args, iou, class_thresh def 
summarize_metrics(metrics): # metrics (dict) : {fp : fp_metrics} # where - # fp_metrics (dict) : {class_label: {'TP': int, 'FP' : int, 'FN' : int}} + # fp_metrics (dict) : {class_label: {'TP': int, 'FP' : int, 'FN' : int, 'TP_seg' : int, 'FP_seg' : int, 'FN_seg' : int}} fps = sorted(metrics.keys()) class_labels = sorted(metrics[fps[0]].keys()) - overall = { l: {'TP' : 0, 'FP' : 0, 'FN' : 0} for l in class_labels} + overall = { l: {'TP' : 0, 'FP' : 0, 'FN' : 0, 'TP_seg' : 0, 'FP_seg' : 0, 'FN_seg' : 0} for l in class_labels} for fp in fps: for l in class_labels: @@ -305,36 +367,60 @@ def summarize_metrics(metrics): overall[l]['TP'] += counts['TP'] overall[l]['FP'] += counts['FP'] overall[l]['FN'] += counts['FN'] + overall[l]['TP_seg'] += counts['TP_seg'] + overall[l]['FP_seg'] += counts['FP_seg'] + overall[l]['FN_seg'] += counts['FN_seg'] for l in class_labels: tp = overall[l]['TP'] fp = overall[l]['FP'] fn = overall[l]['FN'] + tp_seg = overall[l]['TP_seg'] + fp_seg = overall[l]['FP_seg'] + fn_seg = overall[l]['FN_seg'] if tp + fp == 0: prec = 1 else: prec = tp / (tp + fp) overall[l]['precision'] = prec + + if tp_seg + fp_seg == 0: + prec_seg = 1 + else: + prec_seg = tp_seg / (tp_seg + fp_seg) + overall[l]['precision_seg'] = prec_seg if tp + fn == 0: rec = 1 else: rec = tp / (tp + fn) overall[l]['recall'] = rec + + if tp_seg + fn_seg == 0: + rec_seg = 1 + else: + rec_seg = tp_seg / (tp_seg + fn_seg) + overall[l]['recall_seg'] = rec_seg if prec + rec == 0: f1 = 0 else: f1 = 2*prec*rec / (prec + rec) overall[l]['f1'] = f1 + + if prec_seg + rec_seg == 0: + f1_seg = 0 + else: + f1_seg = 2*prec_seg*rec_seg / (prec_seg + rec_seg) + overall[l]['f1_seg'] = f1_seg return overall def macro_metrics(summary): - # summary (dict) : {class_label: {'f1' : float, 'precision' : float, 'recall' : float, 'TP': int, 'FP' : int, 'FN' : int}} + # summary (dict) : {class_label: {'f1' : float, 'precision' : float, 'recall' : float, 'f1_seg' : float, 'precision_seg' : float, 'recall_seg' : float, 'TP': int, 'FP' : int, 'FN' : int, TP_seg': int, 'FP_seg' : int, 'FN_seg' : int}} - metrics = ['f1', 'precision', 'recall'] + metrics = ['f1', 'precision', 'recall', 'f1_seg', 'precision_seg', 'recall_seg'] macro = {} @@ -387,6 +473,7 @@ def predict_and_generate_manifest(model, dataloader_dict, args, verbose = True): fns = [] predictions_fps = [] annotations_fps = [] + durations = [] for fn in dataloader_dict: detections, regressions, classifications = generate_predictions(model, dataloader_dict[fn], args, verbose=verbose) @@ -398,8 +485,9 @@ def predict_and_generate_manifest(model, dataloader_dict, args, verbose = True): fns.append(fn) predictions_fps.append(predictions_fp) annotations_fps.append(annotations_fp) + durations.append(np.shape(detections)[0]*args.scale_factor/args.sr) - manifest = pd.DataFrame({'filename' : fns, 'predictions_fp' : predictions_fps, 'annotations_fp' : annotations_fps}) + manifest = pd.DataFrame({'filename' : fns, 'predictions_fp' : predictions_fps, 'annotations_fp' : annotations_fps, 'duration_sec' : durations}) return manifest def evaluate_based_on_manifest(manifest, args, output_dir = None, iou = 0.5, class_threshold = 0.0): @@ -411,8 +499,9 @@ def evaluate_based_on_manifest(manifest, args, output_dir = None, iou = 0.5, cla fn = row['filename'] predictions_fp = row['predictions_fp'] annotations_fp = row['annotations_fp'] + duration = row['duration_sec'] - metrics[fn] = get_metrics(predictions_fp, annotations_fp, args, iou, class_threshold) + metrics[fn] = get_metrics(predictions_fp, 
annotations_fp, args, iou, class_threshold, duration) confusion_matrix[fn], confusion_matrix_labels = get_confusion_matrix(predictions_fp, annotations_fp, args, iou, class_threshold) if output_dir is not None: diff --git a/voxaboxen/evaluation/metrics.py b/voxaboxen/evaluation/metrics.py index 1e277f8..b80d557 100644 --- a/voxaboxen/evaluation/metrics.py +++ b/voxaboxen/evaluation/metrics.py @@ -104,6 +104,57 @@ def iou(ref, est, method="fast"): S[ref_id, matching_est_id] = intersection_over_union return S + +def compute_intersection(ref, est, method="fast"): + """Compute pairwise intersection between reference + events and estimated events. + Let us denote by a_i and b_i the onset and offset of reference event i. + Let us denote by u_j and v_j the onset and offset of estimated event j. + The Intersection between events i and j is defined as + (min(b_i, v_j)-max(a_i, u_j)) + if the events are non-disjoint, and equal to zero otherwise. + Parameters + ---------- + ref: np.ndarray [shape=(2, n)], real-valued + Array of reference events. Each column is an event. + The first row denotes onset times and the second row denotes offset times. + est: np.ndarray [shape=(2, m)], real-valued + Array of estimated events. Each column is an event. + The first row denotes onset times and the second row denotes offset times. + method: str, optional. + If "fast" (default), computes pairwise intersections via a custom + dynamic programming algorithm, see fast_intersect. + If "slow", computes pairwise intersections via bruteforce quadratic + search, see slow_intersect. + Returns + ------- + S: scipy.sparse.dok.dok_matrix, real-valued + Sparse 2-D matrix. S[i,j] contains the Intersection between ref[i] and est[j] + if these events are non-disjoint and zero otherwise. + """ + n_refs = ref.shape[1] + n_ests = est.shape[1] + S = scipy.sparse.dok_matrix((n_refs, n_ests)) + + if method == "fast": + matches = fast_intersect(ref, est) + elif method == "slow": + matches = slow_intersect(ref, est) + + for ref_id in range(n_refs): + matching_ests = matches[ref_id] + ref_on = ref[0, ref_id] + ref_off = ref[1, ref_id] + + for matching_est_id in matching_ests: + est_on = est[0, matching_est_id] + est_off = est[1, matching_est_id] + intersection = min(ref_off, est_off) - max(ref_on, est_on) + # union = max(ref_off, est_off) - min(ref_on, est_on) + # intersection_over_union = intersection / union + S[ref_id, matching_est_id] = intersection #_over_union + + return S def match_events(ref, est, min_iou=0.0, method="fast"): diff --git a/voxaboxen/evaluation/raven_utils.py b/voxaboxen/evaluation/raven_utils.py index 7178452..fe5f454 100644 --- a/voxaboxen/evaluation/raven_utils.py +++ b/voxaboxen/evaluation/raven_utils.py @@ -45,6 +45,8 @@ def load_audio(self, fp): self.samples, self.sr = librosa.load(fp, sr = None) self.duration = len(self.samples) / self.sr + + def play_audio(self, start_sec, end_sec): start_sample = int(self.sr * start_sec) end_sample = int(self.sr *end_sec) @@ -82,16 +84,38 @@ def compute_matching(self, IoU_minimum = 0.5): self.matched_annotations = [p[0] for p in self.matching] self.matched_predictions = [p[1] for p in self.matching] - def evaluate(self): + def evaluate(self): + + eval_sr = 50 + dur_samples = int(self.duration * eval_sr) # compute frame-wise metrics at 50Hz if self.label_set is None: TP = len(self.matching) FP = len(self.predictions) - TP FN = len(self.annotations) - TP - return {'all' : {'TP' : TP, 'FP' : FP, 'FN' : FN}} + + # segmentation-based metrics + seg_annotations = 
np.zeros((dur_samples,)) + seg_predictions = np.zeros((dur_samples,)) + + for i, row in self.annotations.iterrows(): + start_sample = int(row['Begin Time (s)'] * eval_sr) + end_sample = min(int(row['End Time (s)'] * eval_sr), dur_samples) + seg_annotations[start_sample:end_sample] = 1 + + for i, row in self.predictions.iterrows(): + start_sample = int(row['Begin Time (s)'] * eval_sr) + end_sample = min(int(row['End Time (s)'] * eval_sr), dur_samples) + seg_predictions[start_sample:end_sample] = 1 + + TP_seg = int((seg_predictions * seg_annotations).sum()) + FP_seg = int((seg_predictions * (1-seg_annotations)).sum()) + FN_seg = int(((1-seg_predictions) * seg_annotations).sum()) + + return {'all' : {'TP' : TP, 'FP' : FP, 'FN' : FN, 'TP_seg' : TP_seg, 'FP_seg' : FP_seg, 'FN_seg' : FN_seg}} else: - out = {label : {'TP':0, 'FP':0, 'FN' : 0} for label in self.label_set} + out = {label : {'TP':0, 'FP':0, 'FN' : 0, 'TP_seg':0, 'FP_seg':0, 'FN_seg':0} for label in self.label_set} pred_label = np.array(self.predictions['Annotation']) annot_label = np.array(self.annotations['Annotation']) for p in self.matching: @@ -111,6 +135,30 @@ def evaluate(self): out[label]['FP'] = out[label]['FP'] + n_pred - out[label]['TP'] out[label]['FN'] = out[label]['FN'] + n_annot - out[label]['TP'] + # segmentation-based metrics + seg_annotations = np.zeros((dur_samples,)) + seg_predictions = np.zeros((dur_samples,)) + + annot_sub = self.annotations[self.annotations["Annotation"] == label] + pred_sub = self.predictions[self.predictions["Annotation"] == label] + + for i, row in annot_sub.iterrows(): + start_sample = int(row['Begin Time (s)'] * eval_sr) + end_sample = min(int(row['End Time (s)'] * eval_sr), dur_samples) + seg_annotations[start_sample:end_sample] = 1 + + for i, row in pred_sub.iterrows(): + start_sample = int(row['Begin Time (s)'] * eval_sr) + end_sample = min(int(row['End Time (s)'] * eval_sr), dur_samples) + seg_predictions[start_sample:end_sample] = 1 + + TP_seg = int((seg_predictions * seg_annotations).sum()) + FP_seg = int((seg_predictions * (1-seg_annotations)).sum()) + FN_seg = int(((1-seg_predictions) * seg_annotations).sum()) + out[label]['TP_seg'] = TP_seg + out[label]['FP_seg'] = FP_seg + out[label]['FN_seg'] = FN_seg + return out def confusion_matrix(self): diff --git a/voxaboxen/inference/inference.py b/voxaboxen/inference/inference.py index f223587..d140180 100644 --- a/voxaboxen/inference/inference.py +++ b/voxaboxen/inference/inference.py @@ -5,7 +5,7 @@ from voxaboxen.inference.params import parse_inference_args from voxaboxen.training.params import load_params -from voxaboxen.model.model import DetectionModel, DetectionModelStereo +from voxaboxen.model.model import DetectionModel, DetectionModelStereo, DetectionModelMultichannel from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table from voxaboxen.data.data import get_single_clip_data @@ -23,6 +23,8 @@ def inference(inference_args): # model if hasattr(args,'stereo') and args.stereo: model = DetectionModelStereo(args) + elif hasattr(args,'multichannel') and args.multichannel: + model = DetectionModelMultichannel(args) else: model = DetectionModel(args) model_checkpoint_fp = os.path.join(args.experiment_dir, "model.pt") diff --git a/voxaboxen/model/model.py b/voxaboxen/model/model.py index b53ee2f..9e3c4aa 100644 --- a/voxaboxen/model/model.py +++ b/voxaboxen/model/model.py @@ -80,7 +80,7 @@ def generate_features(self, x): Input x (Tensor): (batch, time) (time at 16000 Hz, audio_sr) Returns - features 
(Tensor): (batch, time) (time at 50 Hz, aves_sr) + features (Tensor): (batch, time, embedding_dim) (time at 50 Hz, aves_sr) """ expected_dur_output = math.ceil(x.size(-1)/self.args.scale_factor) @@ -157,6 +157,79 @@ def forward(self, x): detection_probs = torch.sigmoid(detection_logits) return detection_probs, regression, class_logits + +class DetectionModelMultichannel(DetectionModel): + # supports >1 channel, but unlike Stereo model does not assume the order of channels matters. + def __init__(self, args, embedding_dim=768): + super().__init__(args, embedding_dim=embedding_dim) + self.n_classes = len(args.label_set) + + def forward(self, x): + """ + Input + x (Tensor): (batch, channels, time) (time at 16000 Hz, audio_sr) + Returns + detection_probs (Tensor): (batch, time,) (time at 50 Hz, aves_sr) + regression (Tensor): (batch, time,) (time at 50 Hz, aves_sr) + class_logits (Tensor): (batch, time, n_classes) (time at 50 Hz, aves_sr) + + """ + +# expected_dur_output = math.ceil(x.size(-1)/self.args.scale_factor) + +# x = x-torch.mean(x,axis=-1,keepdim=True) +# feats=[] +# for i in range(x.size(1)): +# feats.append(self.encoder(x[:,i,:])) +# feats = sum(feats) + +# #aves may be off by 1 sample from expected +# pad = expected_dur_output - feats.size(1) +# if pad>0: +# feats = F.pad(feats, (0,0,0,pad), mode='reflect') + +# detection_logits, regression, class_logits = self.detection_head(feats) +# detection_probs = torch.sigmoid(detection_logits) + + + expected_dur_output = math.ceil(x.size(-1)/self.args.scale_factor) + + x = x-torch.mean(x,axis=-1,keepdim=True) + + detection_logits = [] + regression = [] + class_logits = [] + + for i in range(x.size(1)): + feats = self.encoder(x[:,i,:]) + #aves may be off by 1 sample from expected + pad = expected_dur_output - feats.size(1) + if pad>0: + feats = F.pad(feats, (0,0,0,pad), mode='reflect') + + dl, rg, cl = self.detection_head(feats) + detection_logits.append(dl) + regression.append(rg) + class_logits.append(cl) + + detection_logits = torch.stack(detection_logits, dim=0) + regression = torch.stack(regression, dim=0) + class_logits = torch.stack(class_logits, dim=0) + + if hasattr(self.args, "segmentation_based") and self.args.segmentation_based: + detection_logits=torch.max(detection_logits,dim=0)[0] + regression=torch.max(regression,dim=0)[0] + class_logits=torch.max(class_logits,dim=0)[0] + else: + mask_based_on_logits = torch.eq(torch.max(detection_logits, dim=0,keepdim=True)[0], detection_logits) + + detection_logits = torch.sum(detection_logits*mask_based_on_logits, dim=0) + regression = torch.sum(regression*mask_based_on_logits,dim=0) + class_logits = torch.sum(class_logits* mask_based_on_logits.unsqueeze(-1).repeat(1,1,1,self.n_classes), dim=0) + + detection_probs = torch.sigmoid(detection_logits) + + return detection_probs, regression, class_logits def rms_and_mixup(X, d, r, y, train, args): @@ -168,26 +241,26 @@ def rms_and_mixup(X, d, r, y, train, args): if args.mixup and train: # TODO: For mixup, add in a check that there aren't extremely overlapping vocs - + batch_size = X.size(0) - + mask = torch.full((X.size(0),1,1), 0.5, device=X.device) mask = torch.bernoulli(mask) - + if len(X.size()) == 2: X_aug = torch.flip(X, (0,)) * mask[:,:,0] elif len(X.size()) == 3: X_aug = torch.flip(X, (0,)) * mask - + d_aug = torch.flip(d, (0,)) * mask[:,:,0] r_aug = torch.flip(r, (0,)) * mask[:,:,0] y_aug = torch.flip(y, (0,)) * mask - - X = (X + X_aug)[:batch_size//2,...] - d = torch.maximum(d, d_aug)[:batch_size//2,...] 
- r = torch.maximum(r, r_aug)[:batch_size//2,...] - y = torch.maximum(y, y_aug)[:batch_size//2,...] - + + X = (X + X_aug)#[:batch_size//2,...] + d = torch.maximum(d, d_aug)#[:batch_size//2,...] + r = torch.maximum(r, r_aug)#[:batch_size//2,...] + y = torch.maximum(y, y_aug)#[:batch_size//2,...] + if args.rms_norm: X = X * (1/2) diff --git a/voxaboxen/project/project_setup.py b/voxaboxen/project/project_setup.py index 09f9a8d..aa8b9ee 100644 --- a/voxaboxen/project/project_setup.py +++ b/voxaboxen/project/project_setup.py @@ -22,7 +22,7 @@ def project_setup(args): for annot_fp in annot_fps: if annot_fp != "None": selection_table = pd.read_csv(annot_fp, delimiter = '\t') - annots = list(selection_table['Annotation']) + annots = list(selection_table['Annotation'].astype(str)) all_annots.extend(annots) label_set = sorted(set(all_annots)) diff --git a/voxaboxen/training/params.py b/voxaboxen/training/params.py index 4c01610..806d63c 100644 --- a/voxaboxen/training/params.py +++ b/voxaboxen/training/params.py @@ -23,14 +23,15 @@ def parse_args(args,allow_unknown=False): # Model parser.add_argument('--sr', type=int, default=16000) parser.add_argument('--scale-factor', type=int, default = 320, help = "downscaling performed by aves") - parser.add_argument('--aves-model-weight-fp', type=str, default = "weights/aves-base-bio.torchaudio.pt") - parser.add_argument('--aves-config-fp', type=str, default = "weights/aves-base-bio.torchaudio.model_config.json") - parser.add_argument('--prediction-scale-factor', type=int, default = 1, help = "downsampling rate from aves sr to prediction sr") + parser.add_argument('--aves-config-fp', type=str, default = "weights/birdaves-biox-base.torchaudio.model_config.json") + parser.add_argument('--prediction-scale-factor', type=int, default = 1, help = "downsampling rate from aves sr to prediction sr. Deprecated.") parser.add_argument('--detection-threshold', type=float, default = 0.5, help = "output probability to count as positive detection") parser.add_argument('--rms-norm', action="store_true", help = "If true, apply rms normalization to each clip") parser.add_argument('--previous-checkpoint-fp', type=str, default=None, help="path to checkpoint of previously trained detection model") parser.add_argument('--aves-url', type=str, default = "https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.pt") - parser.add_argument('--stereo', action='store_true', help="If passed, will process stereo data as stereo") + parser.add_argument('--stereo', action='store_true', help="If passed, will process stereo data as stereo. 
Order of channels matters") +    parser.add_argument('--multichannel', action='store_true', help="If passed, will encode each audio channel separately, then add together the encoded audio before the final layer") +    parser.add_argument('--segmentation-based', action='store_true', help="If passed, will make predictions based on frame-wise segmentations rather than box starts") # Training parser.add_argument('--batch-size', type=int, default=32) @@ -60,6 +61,8 @@ def parse_args(args,allow_unknown=False): parser.add_argument('--soft-nms-sigma', type = float, default = 0.5) parser.add_argument('--soft-nms-thresh', type = float, default = 0.001) parser.add_argument('--nms-thresh', type = float, default = 0.5) + parser.add_argument('--delete-short-dur-sec', type=float, default=0.1, help="if using the segmentation-based model, delete vox shorter than this as a post-processing step") + parser.add_argument('--fill-holes-dur-sec', type=float, default=0.1, help="if using the segmentation-based model, fill holes shorter than this as a post-processing step") if allow_unknown: args, remaining = parser.parse_known_args(args) @@ -118,4 +121,8 @@ def load_params(fp): def check_config(args): assert args.end_mask_perc < 0.25, "Masking above 25% of each end during training will interfere with inference" - assert ((args.clip_duration * args.sr)/(4*args.scale_factor)).is_integer(), "Must pick clip duration to ensure no rounding errors during inference" \ No newline at end of file + assert ((args.clip_duration * args.sr)/(4*args.scale_factor)).is_integer(), "Must pick clip duration to ensure no rounding errors during inference" + if args.segmentation_based and (args.rho!=1): + import warnings + warnings.warn("when using the segmentation-based framework, we recommend setting args.rho=1") + \ No newline at end of file diff --git a/voxaboxen/training/train.py b/voxaboxen/training/train.py index 5bb7987..58abb97 100644 --- a/voxaboxen/training/train.py +++ b/voxaboxen/training/train.py @@ -1,6 +1,7 @@ import numpy as np import torch import torch.nn.functional as F +import torchvision import tqdm from functools import partial import os @@ -27,7 +28,7 @@ def train(model, args): cp = torch.load(args.previous_checkpoint_fp) model.load_state_dict(cp["model_state_dict"]) - detection_loss_fn = modified_focal_loss + detection_loss_fn = get_detection_loss_fn(args) reg_loss_fn = get_reg_loss_fn(args) class_loss_fn = get_class_loss_fn(args) @@ -291,26 +292,52 @@ def masked_classification_loss(class_logits, y, d, class_weights = None): class_loss = class_loss / n_pos return class_loss + +def segmentation_loss(class_logits, y, d, class_weights=None): + # class_logits (Tensor): [batch, time, n_classes] + # y (Tensor): [batch, time, n_classes] + # d (Tensor) : [batch, time,], float tensor + # class_weight : [n_classes,], float tensor -def get_class_loss_fn(args): - dataloader_temp = get_train_dataloader(args, random_seed_shift = 0) - class_proportions = dataloader_temp.dataset.get_class_proportions() - class_weights = 1. / (class_proportions + 1e-6) - - class_weights = (1. 
/ (np.mean(class_weights) + 1e-6)) * class_weights # normalize so average weight = 1 - - print(f"Using class weights {class_weights}") + default_focal_loss = torchvision.ops.sigmoid_focal_loss(class_logits, y, reduction='mean') + return default_focal_loss + - class_weights = torch.Tensor(class_weights).to(device) - return partial(masked_classification_loss, class_weights = class_weights) +def get_class_loss_fn(args): + if hasattr(args,"segmentation_based") and args.segmentation_based: + return segmentation_loss + else: + dataloader_temp = get_train_dataloader(args, random_seed_shift = 0) + class_proportions = dataloader_temp.dataset.get_class_proportions() + class_weights = 1. / (class_proportions + 1e-6) + + class_weights = (1. / (np.mean(class_weights) + 1e-6)) * class_weights # normalize so average weight = 1 + + print(f"Using class weights {class_weights}") + + class_weights = torch.Tensor(class_weights).to(device) + return partial(masked_classification_loss, class_weights = class_weights) def get_reg_loss_fn(args): - dataloader_temp = get_train_dataloader(args, random_seed_shift = 0) - class_proportions = dataloader_temp.dataset.get_class_proportions() - class_weights = 1. / (class_proportions + 1e-6) - - class_weights = (1. / (np.mean(class_weights) + 1e-6)) * class_weights # normalize so average weight = 1 - - class_weights = torch.Tensor(class_weights).to(device) - return partial(masked_reg_loss, class_weights = class_weights) - \ No newline at end of file + if hasattr(args,"segmentation_based") and args.segmentation_based: + def zrl(regression, r, d, y, class_weights = None): + return torch.tensor(0.) + return zrl + else: + dataloader_temp = get_train_dataloader(args, random_seed_shift = 0) + class_proportions = dataloader_temp.dataset.get_class_proportions() + class_weights = 1. / (class_proportions + 1e-6) + + class_weights = (1. / (np.mean(class_weights) + 1e-6)) * class_weights # normalize so average weight = 1 + + class_weights = torch.Tensor(class_weights).to(device) + return partial(masked_reg_loss, class_weights = class_weights) + +def get_detection_loss_fn(args): + if hasattr(args,"segmentation_based") and args.segmentation_based: + def zdl(pred, gt, pos_loss_weight = 1): + return torch.tensor(0.) 
+ return zdl + else: + return modified_focal_loss + \ No newline at end of file diff --git a/voxaboxen/training/train_model.py b/voxaboxen/training/train_model.py index ee60a62..b59b128 100644 --- a/voxaboxen/training/train_model.py +++ b/voxaboxen/training/train_model.py @@ -1,5 +1,5 @@ from voxaboxen.data.data import get_test_dataloader -from voxaboxen.model.model import DetectionModel, DetectionModelStereo +from voxaboxen.model.model import DetectionModel, DetectionModelStereo, DetectionModelMultichannel from voxaboxen.training.train import train from voxaboxen.training.params import parse_args, set_seed, save_params from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table, get_metrics, summarize_metrics, predict_and_generate_manifest, evaluate_based_on_manifest @@ -27,6 +27,8 @@ def train_model(args): save_params(args) if hasattr(args,'stereo') and args.stereo: model = DetectionModelStereo(args) + elif hasattr(args,'multichannel') and args.multichannel: + model = DetectionModelMultichannel(args) else: model = DetectionModel(args) diff --git a/weights/birdaves-biox-base.torchaudio.model_config.json b/weights/birdaves-biox-base.torchaudio.model_config.json new file mode 100644 index 0000000..17c3154 --- /dev/null +++ b/weights/birdaves-biox-base.torchaudio.model_config.json @@ -0,0 +1,53 @@ +{ + "extractor_mode": "group_norm", + "extractor_conv_layer_config": [ + [ + 512, + 10, + 5 + ], + [ + 512, + 3, + 2 + ], + [ + 512, + 3, + 2 + ], + [ + 512, + 3, + 2 + ], + [ + 512, + 3, + 2 + ], + [ + 512, + 2, + 2 + ], + [ + 512, + 2, + 2 + ] + ], + "extractor_conv_bias": false, + "encoder_embed_dim": 768, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 12, + "encoder_num_heads": 12, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 3072, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.1, + "encoder_layer_norm_first": false, + "encoder_layer_drop": 0.05 +} \ No newline at end of file