diff --git a/README.md b/README.md index 7c69d9f..c8be953 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ For example, say you annotate your audio with the labels Red-eyed Vireo `REVI`, Here are some additional options that can be applied during training: - Flag `--stereo` accepts stereo audio. Order of channels matters; used for e.g. speaker diarization. -- Flag `--multichannel` accepts audio with >1 channel. Order of channels does not matter. +- Flag `--bidirectional` predicts the ends of events in addition to the beginning, matches starts and ends based on IoU. May improve box regression. - Flag `--segmentation-based` switches to a frame-based approach. If used, we recommend putting `--rho=1`. - Flag `--mixup` applies mixup augmentation. diff --git a/voxaboxen/data/data.py b/voxaboxen/data/data.py index 5f29e66..9d53bf1 100644 --- a/voxaboxen/data/data.py +++ b/voxaboxen/data/data.py @@ -1,9 +1,7 @@ -import os import math import numpy as np import pandas as pd import librosa -import warnings from numpy.random import default_rng from intervaltree import IntervalTree @@ -16,16 +14,16 @@ def normalize_sig_np(sig, eps=1e-8): sig = sig / (np.max(np.abs(sig))+eps) return sig - + def crop_and_pad(wav, sr, dur_sec): # crops and pads waveform to be the expected number of samples; used after resampling to ensure proper size target_dur_samples = int(sr * dur_sec) wav = wav[..., :target_dur_samples] - + pad = target_dur_samples - wav.size(-1) if pad > 0: wav = F.pad(wav, (0,pad)) #padding starts from last dims - + return wav class DetectionDataset(Dataset): @@ -47,7 +45,7 @@ def __init__(self, info_df, train, args, random_seed_shift = 0): if self.amp_aug: self.amp_aug_low_r = args.amp_aug_low_r self.amp_aug_high_r = args.amp_aug_high_r - assert (self.amp_aug_low_r >= 0) #and (self.amp_aug_high_r <= 1) and + assert (self.amp_aug_low_r >= 0) #and (self.amp_aug_high_r <= 1) and assert (self.amp_aug_low_r <= self.amp_aug_high_r) self.scale_factor = args.scale_factor @@ -61,14 +59,14 @@ def __init__(self, info_df, train, args, random_seed_shift = 0): self.mono = False else: self.mono = True - + if self.train: self.omit_empty_clip_prob = args.omit_empty_clip_prob self.clip_start_offset = self.rng.integers(0, np.floor(self.clip_hop*self.sr)) / self.sr else: self.omit_empty_clip_prob = 0 self.clip_start_offset = 0 - + self.args=args # make metadata self.make_metadata() @@ -89,15 +87,15 @@ def process_selection_table(self, selection_table_fp): start = row['Begin Time (s)'] end = row['End Time (s)'] label = row['Annotation'] - + if end<=start: continue - + if label in self.label_mapping: label = self.label_mapping[label] else: continue - + if label == self.unknown_label: label_idx = -1 else: @@ -113,7 +111,7 @@ def make_metadata(self): for ii, row in self.info_df.iterrows(): fn = row['fn'] audio_fp = row['audio_fp'] - + duration = librosa.get_duration(path=audio_fp) selection_table_fp = row['selection_table_fp'] @@ -144,10 +142,10 @@ def get_pos_intervals(self, fn, start, end): intervals = [(max(iv.begin, start)-start, min(iv.end, end)-start, iv.data) for iv in intervals] return intervals - + def get_class_proportions(self): counts = np.zeros((self.n_classes,)) - + for k in self.selection_table_dict: st = self.selection_table_dict[k] for interval in st: @@ -156,98 +154,104 @@ def get_class_proportions(self): continue else: counts[annot] += 1 - + total_count = np.sum(counts) proportions = counts / total_count - + return proportions - def get_annotation(self, pos_intervals, audio): - raw_seq_len = 
audio.shape[-1] seq_len = int(math.ceil(raw_seq_len / self.scale_factor_raw_to_prediction)) - regression_anno = np.zeros((seq_len,)) - class_anno = np.zeros((seq_len, self.n_classes)) - anno_sr = int(self.sr // self.scale_factor_raw_to_prediction) - + + regression_annos = np.zeros((seq_len,)) + class_annos = np.zeros((seq_len, self.n_classes)) anchor_annos = [np.zeros(seq_len,)] + rev_regression_annos = np.zeros((seq_len,)) + rev_class_annos = np.zeros((seq_len, self.n_classes)) + rev_anchor_annos = [np.zeros(seq_len,)] for iv in pos_intervals: start, end, class_idx = iv dur = end-start - + dur_samples = np.ceil(dur * anno_sr) + start_idx = int(math.floor(start*anno_sr)) start_idx = max(min(start_idx, seq_len-1), 0) - dur_samples = int(np.ceil(dur * anno_sr)) + + end_idx = int(math.ceil(end*anno_sr)) + end_idx = max(min(end_idx, seq_len-1), 0) + dur_samples = int(np.ceil(dur * anno_sr)) anchor_anno = get_anchor_anno(start_idx, dur_samples, seq_len) anchor_annos.append(anchor_anno) - regression_anno[start_idx] = dur + regression_annos[start_idx] = dur + + rev_anchor_anno = get_anchor_anno(end_idx, dur_samples, seq_len) + rev_anchor_annos.append(rev_anchor_anno) + rev_regression_annos[end_idx] = dur if hasattr(self.args,"segmentation_based") and self.args.segmentation_based: if class_idx == -1: pass else: - class_anno[start_idx:start_idx+dur_samples,class_idx]=1. + class_annos[start_idx:start_idx+dur_samples,class_idx]=1. else: if class_idx != -1: - class_anno[start_idx, class_idx] = 1. + class_annos[start_idx, class_idx] = 1. + rev_class_annos[end_idx, class_idx] = 1. else: - class_anno[start_idx, :] = 1./self.n_classes # if unknown, enforce uncertainty - + class_annos[start_idx, :] = 1./self.n_classes # if unknown, enforce uncertainty + rev_class_annos[end_idx, :] = 1./self.n_classes # if unknown, enforce uncertainty + + anchor_annos = np.stack(anchor_annos) anchor_annos = np.amax(anchor_annos, axis = 0) - - return anchor_annos, regression_anno, class_anno # shapes [time_steps, ], [time_steps, ], [time_steps, n_classes] + rev_anchor_annos = np.stack(rev_anchor_annos) + rev_anchor_annos = np.amax(rev_anchor_annos, axis = 0) + # shapes [time_steps, ], [time_steps, ], [time_steps, n_classes] (times two) + return anchor_annos, regression_annos, class_annos, rev_anchor_annos, rev_regression_annos, rev_class_annos def __getitem__(self, index): fn, audio_fp, start, end = self.metadata[index] - - audio, file_sr = librosa.load(audio_fp, sr=None, offset=start, duration=self.clip_duration, mono=self.mono) + + audio, file_sr = librosa.load(audio_fp, sr=None, offset=start, duration=self.clip_duration, mono=self.mono) audio = torch.from_numpy(audio) - + audio = audio-torch.mean(audio, -1, keepdim=True) if self.amp_aug and self.train: audio = self.augment_amplitude(audio) if file_sr != self.sr: - audio = torchaudio.functional.resample(audio, file_sr, self.sr) - + audio = torchaudio.functional.resample(audio, file_sr, self.sr) + audio = crop_and_pad(audio, self.sr, self.clip_duration) - + pos_intervals = self.get_pos_intervals(fn, start, end) - anchor_anno, regression_anno, class_anno = self.get_annotation(pos_intervals, audio) + anchor_anno, regression_anno, class_anno, rev_anchor_anno, rev_regression_anno, rev_class_anno = self.get_annotation(pos_intervals, audio) - return audio, torch.from_numpy(anchor_anno), torch.from_numpy(regression_anno), torch.from_numpy(class_anno) + return audio, torch.from_numpy(anchor_anno), torch.from_numpy(regression_anno), torch.from_numpy(class_anno), 
torch.from_numpy(rev_anchor_anno), torch.from_numpy(rev_regression_anno), torch.from_numpy(rev_class_anno) def __len__(self): return len(self.metadata) - - + + def get_train_dataloader(args, random_seed_shift = 0): train_info_fp = args.train_info_fp train_info_df = pd.read_csv(train_info_fp) - + train_dataset = DetectionDataset(train_info_df, True, args, random_seed_shift = random_seed_shift) - - # if args.mixup: - # effective_batch_size = args.batch_size*2 # double batch size because half will be discarded before being passed to model - # else: - # effective_batch_size = args.batch_size - - train_dataloader = DataLoader(train_dataset, - batch_size=args.batch_size, #effective_batch_size, + batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, - pin_memory=True, + pin_memory=True, drop_last = True) - + return train_dataloader - + class SingleClipDataset(Dataset): def __init__(self, audio_fp, clip_hop, args, annot_fp = None): # waveform (samples,) @@ -265,26 +269,26 @@ def __init__(self, audio_fp, clip_hop, args, annot_fp = None): self.mono = False else: self.mono = True - + def __len__(self): return self.num_clips def __getitem__(self, idx): """ Map int idx to dict of torch tensors """ start = idx * self.clip_hop - + audio, file_sr = librosa.load(self.audio_fp, sr=None, offset=start, duration=self.clip_duration, mono=self.mono) audio = torch.from_numpy(audio) - - + + audio = audio-torch.mean(audio, -1, keepdim=True) if file_sr != self.sr: - audio = torchaudio.functional.resample(audio, file_sr, self.sr) - + audio = torchaudio.functional.resample(audio, file_sr, self.sr) + audio = crop_and_pad(audio, self.sr, self.clip_duration) - + return audio - + def get_single_clip_data(audio_fp, clip_hop, args, annot_fp = None): return DataLoader( SingleClipDataset(audio_fp, clip_hop, args, annot_fp = annot_fp), @@ -296,33 +300,33 @@ def get_single_clip_data(audio_fp, clip_hop, args, annot_fp = None): ) def get_val_dataloader(args): - val_info_fp = args.val_info_fp + val_info_fp = args.val_info_fp val_info_df = pd.read_csv(val_info_fp) - + val_dataloaders = {} - + for i in range(len(val_info_df)): fn = val_info_df.iloc[i]['fn'] audio_fp = val_info_df.iloc[i]['audio_fp'] annot_fp = val_info_df.iloc[i]['selection_table_fp'] val_dataloaders[fn] = get_single_clip_data(audio_fp, args.clip_duration/2, args, annot_fp = annot_fp) - + return val_dataloaders def get_test_dataloader(args): test_info_fp = args.test_info_fp test_info_df = pd.read_csv(test_info_fp) - + test_dataloaders = {} - + for i in range(len(test_info_df)): fn = test_info_df.iloc[i]['fn'] audio_fp = test_info_df.iloc[i]['audio_fp'] annot_fp = test_info_df.iloc[i]['selection_table_fp'] test_dataloaders[fn] = get_single_clip_data(audio_fp, args.clip_duration/2, args, annot_fp = annot_fp) - + return test_dataloaders - + def get_anchor_anno(start_idx, dur_samples, seq_len): # start times plus gaussian blur # std setting follows CornerNet, where adaptive standard deviation is set to 1/3 image radius @@ -331,4 +335,5 @@ def get_anchor_anno(start_idx, dur_samples, seq_len): x = x / (2 * std**2) x = np.exp(-x) return x - \ No newline at end of file + + diff --git a/voxaboxen/evaluation/evaluation.py b/voxaboxen/evaluation/evaluation.py index f623652..e175f4c 100644 --- a/voxaboxen/evaluation/evaluation.py +++ b/voxaboxen/evaluation/evaluation.py @@ -12,10 +12,11 @@ from voxaboxen.evaluation.raven_utils import Clip from voxaboxen.model.model import rms_and_mixup from voxaboxen.evaluation.nms import nms, soft_nms 
+plt.switch_backend('agg') device = "cuda" if torch.cuda.is_available() else "cpu" -def pred2bbox(detection_peaks, detection_probs, durations, class_idxs, class_probs, pred_sr): +def pred2bbox(detection_peaks, detection_probs, durations, class_idxs, class_probs, pred_sr, is_rev): ''' detection_peaks, detection_probs, durations, class_idxs, class_probs : shape=(num_frames,) @@ -29,23 +30,27 @@ def pred2bbox(detection_peaks, detection_probs, durations, class_idxs, class_pro detection_probs_sub = [] class_idxs_sub = [] class_probs_sub = [] - + for i in range(len(detection_peaks)): duration = durations[i] - start = detection_peaks[i] - if duration <= 0: continue - - bbox = [start, start+duration] + + if is_rev: + end = detection_peaks[i] + bbox = [end-duration, end] + else: + start = detection_peaks[i] + bbox = [start, start+duration] + bboxes.append(bbox) - + detection_probs_sub.append(detection_probs[i]) class_idxs_sub.append(class_idxs[i]) class_probs_sub.append(class_probs[i]) - + return np.array(bboxes), np.array(detection_probs_sub), np.array(class_idxs_sub), np.array(class_probs_sub) - + def bbox2raven(bboxes, class_idxs, label_set, detection_probs, class_probs, unknown_label): ''' output bounding boxes to a selection table @@ -55,18 +60,18 @@ def bbox2raven(bboxes, class_idxs, label_set, detection_probs, class_probs, unkn bboxes: numpy array shape=(num_bboxes, 2) - + class_idxs: numpy array shape=(num_bboxes,) label_set: list - + detection_probs: numpy array shape =(num_bboxes,) - + class_probs: numpy array shape = (num_bboxes,) - + unknown_label: str ''' @@ -74,14 +79,14 @@ def bbox2raven(bboxes, class_idxs, label_set, detection_probs, class_probs, unkn return [['Begin Time (s)', 'End Time (s)', 'Annotation', 'Detection Prob', 'Class Prob']] columns = ['Begin Time (s)', 'End Time (s)', 'Annotation', 'Detection Prob', 'Class Prob'] - - + + def label_idx_to_label(i): if i==-1: return unknown_label else: return label_set[i] - + out_data = [[bbox[0], bbox[1], label_idx_to_label(int(c)), dp, cp] for bbox, c, dp, cp in zip(bboxes, class_idxs, detection_probs, class_probs)] out_data = sorted(out_data, key=lambda x: x[:2]) @@ -110,77 +115,101 @@ def generate_predictions(model, single_clip_dataloader, args, verbose = True): model = model.to(device) model.eval() - + all_detections = [] all_regressions = [] - all_classifications = [] - + all_classifs = [] + all_rev_detections = [] + all_rev_regressions = [] + all_rev_classifs = [] + if verbose: iterator = tqdm.tqdm(enumerate(single_clip_dataloader), total=len(single_clip_dataloader)) else: iterator = enumerate(single_clip_dataloader) - + with torch.no_grad(): for i, X in iterator: X = X.to(device = device, dtype = torch.float) X, _, _, _ = rms_and_mixup(X, None, None, None, False, args) - - detection, regression, classification = model(X) + + model_outputs = model(X) + assert isinstance(model_outputs, tuple) + all_detections.append(model_outputs[0]) + all_regressions.append(model_outputs[1]) if hasattr(args, "segmentation_based") and args.segmentation_based: - classification=torch.nn.functional.sigmoid(classification) + classification=torch.nn.functional.sigmoid(model_outputs[2]) else: - classification=torch.nn.functional.softmax(classification, dim=-1) - - all_detections.append(detection) - all_regressions.append(regression) - all_classifications.append(classification) - + classification=torch.nn.functional.softmax(model_outputs[2], dim=-1) + all_classifs.append(classification) + if model.is_bidirectional: + assert all(x is not None for x in 
model_outputs) + all_rev_detections.append(model_outputs[3]) + all_rev_regressions.append(model_outputs[4]) + all_rev_classifs.append(model_outputs[5].softmax(-1)) # segmentation-based is not used when bidirectional + else: + assert all(x is None for x in model_outputs[3:]) + + # if args.is_test and i==15: + # break + all_detections = torch.cat(all_detections) all_regressions = torch.cat(all_regressions) - all_classifications = torch.cat(all_classifications) + all_classifs = torch.cat(all_classifs) + if model.is_bidirectional: + all_rev_detections = torch.cat(all_rev_detections) + all_rev_regressions = torch.cat(all_rev_regressions) + all_rev_classifs = torch.cat(all_rev_classifs) - # we use half overlapping windows, need to throw away boundary predictions - # See get_val_dataloader and get_test_dataloader in data.py - - ######## Todo: Need better checking that preds are the correct dur + ######## Todo: Need better checking that preds are the correct dur assert all_detections.size(dim=1) % 2 == 0 first_quarter_window_dur_samples=all_detections.size(dim=1)//4 last_quarter_window_dur_samples=(all_detections.size(dim=1)//2)-first_quarter_window_dur_samples - - # assemble detections - beginning_bit = all_detections[0,:first_quarter_window_dur_samples] - end_bit = all_detections[-1,-last_quarter_window_dur_samples:] - detections_clipped = all_detections[:,first_quarter_window_dur_samples:-last_quarter_window_dur_samples] - all_detections = torch.reshape(detections_clipped, (-1,)) - all_detections = torch.cat([beginning_bit, all_detections, end_bit]) - - # assemble regressions - beginning_bit = all_regressions[0,:first_quarter_window_dur_samples] - end_bit = all_regressions[-1,-last_quarter_window_dur_samples:] - regressions_clipped = all_regressions[:,first_quarter_window_dur_samples:-last_quarter_window_dur_samples] - all_regressions = torch.reshape(regressions_clipped, (-1,)) - all_regressions = torch.cat([beginning_bit, all_regressions, end_bit]) - - # assemble classifications - beginning_bit = all_classifications[0,:first_quarter_window_dur_samples, :] - end_bit = all_classifications[-1,-last_quarter_window_dur_samples:, :] - classifications_clipped = all_classifications[:,first_quarter_window_dur_samples:-last_quarter_window_dur_samples,:] - all_classifications = torch.reshape(classifications_clipped, (-1, classifications_clipped.size(-1))) - all_classifications = torch.cat([beginning_bit, all_classifications, end_bit]) - - return all_detections.detach().cpu().numpy(), all_regressions.detach().cpu().numpy(), all_classifications.detach().cpu().numpy() + + def assemble(d, r, c): + """We use half overlapping windows, need to throw away boundary predictions. 
+ See get_val_dataloader and get_test_dataloader in data.py""" + # assemble detections + beginning_d_bit = d[0,:first_quarter_window_dur_samples] + end_d_bit = d[-1,-last_quarter_window_dur_samples:] + d_clipped = d[:,first_quarter_window_dur_samples:-last_quarter_window_dur_samples] + middle_d_bit = torch.reshape(d_clipped, (-1,)) + assembled_d = torch.cat([beginning_d_bit, middle_d_bit, end_d_bit]) + + # assemble regressions + beginning_r_bit = r[0,:first_quarter_window_dur_samples] + end_r_bit = r[-1,-last_quarter_window_dur_samples:] + r_clipped = r[:,first_quarter_window_dur_samples:-last_quarter_window_dur_samples] + middle_r_bit = torch.reshape(r_clipped, (-1,)) + assembled_r = torch.cat([beginning_r_bit, middle_r_bit, end_r_bit]) + + # assemble classifs + beginning_c_bit = c[0,:first_quarter_window_dur_samples, :] + end_c_bit = c[-1,-last_quarter_window_dur_samples:, :] + c_clipped = c[:,first_quarter_window_dur_samples:-last_quarter_window_dur_samples,:] + middle_c_bit = torch.reshape(c_clipped, (-1, c_clipped.size(-1))) + assembled_c = torch.cat([beginning_c_bit, middle_c_bit, end_c_bit]) + return assembled_d.detach().cpu().numpy(), assembled_r.detach().cpu().numpy(), assembled_c.detach().cpu().numpy(), + + assembled_dets, assembled_regs, assembled_classifs = assemble(all_detections, all_regressions, all_classifs) + if model.is_bidirectional: + assembled_rev_dets, assembled_rev_regs, assembled_rev_classifs = assemble(all_rev_detections, all_rev_regressions, all_rev_classifs) + else: + assembled_rev_dets = assembled_rev_regs = assembled_rev_classifs = None + + return assembled_dets, assembled_regs, assembled_classifs, assembled_rev_dets, assembled_rev_regs, assembled_rev_classifs def generate_features(model, single_clip_dataloader, args, verbose = True): model = model.to(device) model.eval() - + all_features = [] - + if verbose: iterator = tqdm.tqdm(enumerate(single_clip_dataloader), total=len(single_clip_dataloader)) else: iterator = enumerate(single_clip_dataloader) - + with torch.no_grad(): for i, X in iterator: X = X.to(device = device, dtype = torch.float) @@ -188,19 +217,19 @@ def generate_features(model, single_clip_dataloader, args, verbose = True): features = model.generate_features(X) all_features.append(features) all_features = torch.cat(all_features) - - ######## Need better checking that features are the correct dur + + ######## Need better checking that features are the correct dur assert all_features.size(dim=1) % 2 == 0 first_quarter_window_dur_samples=all_features.size(dim=1)//4 last_quarter_window_dur_samples=(all_features.size(dim=1)//2)-first_quarter_window_dur_samples - + # assemble features beginning_bit = all_features[0,:first_quarter_window_dur_samples,:] end_bit = all_features[-1,-last_quarter_window_dur_samples:,:] features_clipped = all_features[:,first_quarter_window_dur_samples:-last_quarter_window_dur_samples,:] all_features = torch.reshape(features_clipped, (-1, features_clipped.size(-1))) all_features = torch.cat([beginning_bit, all_features, end_bit]) - + return all_features.detach().cpu().numpy() def fill_holes(m, max_hole): @@ -236,16 +265,22 @@ def delete_short(m, min_pos): m[clip[0]:clip[1]] = True return m - -def export_to_selection_table(detections, regressions, classifications, fn, args, verbose=True, target_dir=None, detection_threshold = 0.5, classification_threshold = 0): - + +def export_to_selection_table(detections, regressions, classifications, fn, args, is_bck, verbose=True, target_dir=None, detection_threshold = 0.5, 
classification_threshold = 0): + + if hasattr(args, "bidirectional") and args.bidirectional: + if is_bck: + fn += '-bck' + else: + fn += '-fwd' + if target_dir is None: target_dir = args.experiment_output_dir if hasattr(args, "segmentation_based") and args.segmentation_based: pred_sr = args.sr // (args.scale_factor * args.prediction_scale_factor) bboxes = [] - detection_probs = [] + det_probs = [] class_idxs = [] class_probs = [] for c in range(np.shape(classifications)[1]): @@ -267,19 +302,19 @@ def export_to_selection_table(detections, regressions, classifications, fn, args bbox = [start/pred_sr,end/pred_sr] bboxes.append(bbox) - detection_probs.append(classifications_sub[start:end].mean()) + det_probs.append(classifications_sub[start:end].mean()) class_idxs.append(c) class_probs.append(classifications_sub[start:end].mean()) bboxes=np.array(bboxes) - detection_probs=np.array(detection_probs) + det_probs=np.array(det_probs) class_idxs=np.array(class_idxs) class_probs=np.array(class_probs) else: ## peaks detection_peaks, properties = find_peaks(detections, height = detection_threshold, distance=args.peak_distance) - detection_probs = properties['peak_heights'] + det_probs = properties['peak_heights'] ## regressions and classifications durations = [] @@ -305,50 +340,50 @@ def export_to_selection_table(detections, regressions, classifications, fn, args pred_sr = args.sr // (args.scale_factor * args.prediction_scale_factor) - bboxes, detection_probs, class_idxs, class_probs = pred2bbox(detection_peaks, detection_probs, durations, class_idxs, class_probs, pred_sr) - + bboxes, det_probs, class_idxs, class_probs = pred2bbox(detection_peaks, det_probs, durations, class_idxs, class_probs, pred_sr, is_bck) + if args.nms == "soft_nms": - bboxes, detection_probs, class_idxs, class_probs = soft_nms(bboxes, detection_probs, class_idxs, class_probs, sigma = args.soft_nms_sigma, thresh = args.detection_threshold) + bboxes, det_probs, class_idxs, class_probs = soft_nms(bboxes, det_probs, class_idxs, class_probs, sigma=args.soft_nms_sigma, thresh=detection_threshold) elif args.nms == "nms": - bboxes, detection_probs, class_idxs, class_probs = nms(bboxes, detection_probs, class_idxs, class_probs, iou_thresh = args.nms_thresh) - + bboxes, det_probs, class_idxs, class_probs = nms(bboxes, det_probs, class_idxs, class_probs, iou_thresh=args.nms_thresh) + if verbose: - print(f"Found {len(detection_probs)} boxes") - + print(f"Found {len(det_probs)} boxes") + target_fp = os.path.join(target_dir, f"peaks_pred_{fn}.txt") - - st = bbox2raven(bboxes, class_idxs, args.label_set, detection_probs, class_probs, args.unknown_label) + + st = bbox2raven(bboxes, class_idxs, args.label_set, det_probs, class_probs, args.unknown_label) write_tsv(target_fp, st) - + return target_fp + def get_metrics(predictions_fp, annotations_fp, args, iou, class_threshold, duration): c = Clip(label_set=args.label_set, unknown_label=args.unknown_label) c.duration = duration - c.load_predictions(predictions_fp) c.threshold_class_predictions(class_threshold) c.load_annotations(annotations_fp, label_mapping = args.label_mapping) - + metrics = {} - + c.compute_matching(IoU_minimum = iou) metrics = c.evaluate() - + return metrics def get_confusion_matrix(predictions_fp, annotations_fp, args, iou, class_threshold): c = Clip(label_set=args.label_set, unknown_label=args.unknown_label) - + c.load_predictions(predictions_fp) c.threshold_class_predictions(class_threshold) c.load_annotations(annotations_fp, label_mapping = args.label_mapping) - + 
confusion_matrix = {} - + c.compute_matching(IoU_minimum = iou) confusion_matrix, confusion_matrix_labels = c.confusion_matrix() - + return confusion_matrix, confusion_matrix_labels def summarize_metrics(metrics): @@ -360,7 +395,7 @@ def summarize_metrics(metrics): class_labels = sorted(metrics[fps[0]].keys()) overall = { l: {'TP' : 0, 'FP' : 0, 'FN' : 0, 'TP_seg' : 0, 'FP_seg' : 0, 'FN_seg' : 0} for l in class_labels} - + for fp in fps: for l in class_labels: counts = metrics[fp][l] @@ -421,9 +456,8 @@ def macro_metrics(summary): # summary (dict) : {class_label: {'f1' : float, 'precision' : float, 'recall' : float, 'f1_seg' : float, 'precision_seg' : float, 'recall_seg' : float, 'TP': int, 'FP' : int, 'FN' : int, TP_seg': int, 'FP_seg' : int, 'FN_seg' : int}} metrics = ['f1', 'precision', 'recall', 'f1_seg', 'precision_seg', 'recall_seg'] - macro = {} - + for metric in metrics: e = [] @@ -431,12 +465,12 @@ def macro_metrics(summary): m = summary[l][metric] e.append(m) macro[metric] = float(np.mean(e)) - + return macro def plot_confusion_matrix(data, label_names, target_dir, name=""): - - fig = plt.figure(num=None, figsize=(12, 8), dpi=80, facecolor='w', edgecolor='k') + + fig = plt.figure(num=None, figsize=(16, 12), dpi=80, facecolor='w', edgecolor='k') plt.clf() ax = fig.add_subplot(111) ax.set_aspect(1) @@ -445,11 +479,11 @@ def plot_confusion_matrix(data, label_names, target_dir, name=""): ax.set_yticks([i + 0.5 for i in range(len(label_names))]) ax.set_yticklabels(label_names, rotation = 0) ax.set_xticks([i + 0.5 for i in range(len(label_names))]) - ax.set_xticklabels(label_names, rotation = -15) + ax.set_xticklabels(label_names, rotation = -90) ax.set_ylabel('Prediction') ax.set_xlabel('Annotation') plt.title(name) - + plt.savefig(os.path.join(target_dir, f"{name}_confusion_matrix.svg")) plt.close() @@ -458,69 +492,124 @@ def summarize_confusion_matrix(confusion_matrix, confusion_matrix_labels): # confusion_matrix (dict) : {fp : fp_cm} # where # fp_cm : numpy array - + fps = sorted(confusion_matrix.keys()) l = len(confusion_matrix_labels) - + overall = np.zeros((l, l)) - + for fp in fps: overall += confusion_matrix[fp] - + return overall, confusion_matrix_labels def predict_and_generate_manifest(model, dataloader_dict, args, verbose = True): fns = [] - predictions_fps = [] + fwd_predictions_fps = [] + bck_predictions_fps = [] annotations_fps = [] durations = [] for fn in dataloader_dict: - detections, regressions, classifications = generate_predictions(model, dataloader_dict[fn], args, verbose=verbose) - - predictions_fp = export_to_selection_table(detections, regressions, classifications, fn, args, verbose = verbose, detection_threshold = args.detection_threshold) - + fwd_detections, fwd_regressions, fwd_classifications, bck_detections, bck_regressions, bck_classifications = generate_predictions(model, dataloader_dict[fn], args, verbose=verbose) + + fwd_predictions_fp = export_to_selection_table(fwd_detections, fwd_regressions, fwd_classifications, fn, args, is_bck=False, verbose=verbose, detection_threshold=args.detection_threshold) + if model.is_bidirectional: + assert all(x is not None for x in [bck_detections, bck_classifications, bck_regressions]) + bck_predictions_fp = export_to_selection_table(bck_detections, bck_regressions, bck_classifications, fn, args, is_bck=True, verbose=verbose, detection_threshold=args.detection_threshold) + else: + assert all(x is None for x in [bck_detections, bck_classifications, bck_regressions]) + bck_predictions_fp = None annotations_fp = 
dataloader_dict[fn].dataset.annot_fp - + fns.append(fn) - predictions_fps.append(predictions_fp) + fwd_predictions_fps.append(fwd_predictions_fp) + bck_predictions_fps.append(bck_predictions_fp) annotations_fps.append(annotations_fp) - durations.append(np.shape(detections)[0]*args.scale_factor/args.sr) - - manifest = pd.DataFrame({'filename' : fns, 'predictions_fp' : predictions_fps, 'annotations_fp' : annotations_fps, 'duration_sec' : durations}) + durations.append(np.shape(fwd_detections)[0]*args.scale_factor/args.sr) + + manifest = pd.DataFrame({'filename' : fns, 'fwd_predictions_fp' : fwd_predictions_fps, 'bck_predictions_fp' : bck_predictions_fps, 'annotations_fp' : annotations_fps, 'duration_sec' : durations}) return manifest - -def evaluate_based_on_manifest(manifest, args, output_dir = None, iou = 0.5, class_threshold = 0.0): - - metrics = {} - confusion_matrix = {} - + +def evaluate_based_on_manifest(manifest, args, output_dir, iou, class_threshold, comb_discard_threshold): + pred_types = ('fwd', 'bck', 'comb', 'match') if args.bidirectional else ('fwd',) + metrics = {p:{} for p in pred_types} + conf_mats = {p:{} for p in pred_types} + conf_mat_labels = {} + for i, row in manifest.iterrows(): - fn = row['filename'] - predictions_fp = row['predictions_fp'] - annotations_fp = row['annotations_fp'] - duration = row['duration_sec'] - - metrics[fn] = get_metrics(predictions_fp, annotations_fp, args, iou, class_threshold, duration) - confusion_matrix[fn], confusion_matrix_labels = get_confusion_matrix(predictions_fp, annotations_fp, args, iou, class_threshold) - + fn = row['filename'] + annots_fp = row['annotations_fp'] + duration = row['duration_sec'] + if args.bidirectional: + row['comb_predictions_fp'], row['match_predictions_fp'] = combine_fwd_bck_preds(args.experiment_output_dir, fn, comb_iou_threshold=args.comb_iou_threshold, comb_discard_threshold=comb_discard_threshold) + + for pred_type in pred_types: + preds_fp = row[f'{pred_type}_predictions_fp'] + metrics[pred_type][fn] = get_metrics(preds_fp, annots_fp, args, iou, class_threshold, duration) + conf_mats[pred_type][fn], conf_mat_labels[pred_type] = get_confusion_matrix(preds_fp, annots_fp, args, iou, class_threshold) + if output_dir is not None: if not os.path.exists(output_dir): os.makedirs(output_dir) - + # summarize and save metrics - summary = summarize_metrics(metrics) - metrics['summary'] = summary - macro = macro_metrics(summary) - metrics['macro'] = macro + conf_mat_summaries = {} + for pred_type in pred_types: + summary = summarize_metrics(metrics[pred_type]) + metrics[pred_type]['summary'] = summary + metrics[pred_type]['macro'] = macro_metrics(summary) + conf_mat_summaries[pred_type], confusion_matrix_labels = summarize_confusion_matrix(conf_mats[pred_type], conf_mat_labels[pred_type]) + plot_confusion_matrix(conf_mat_summaries[pred_type].astype(int), confusion_matrix_labels, output_dir, name=f"cm_iou_{iou}_class_threshold_{class_threshold}_{pred_type}") if output_dir is not None: metrics_fp = os.path.join(output_dir, f'metrics_iou_{iou}_class_threshold_{class_threshold}.yaml') with open(metrics_fp, 'w') as f: yaml.dump(metrics, f) - # summarize and save confusion matrix - confusion_matrix_summary, confusion_matrix_labels = summarize_confusion_matrix(confusion_matrix, confusion_matrix_labels) - if output_dir is not None: - plot_confusion_matrix(confusion_matrix_summary.astype(int), confusion_matrix_labels, output_dir, name=f"cm_iou_{iou}_class_threshold_{class_threshold}") - - return metrics, 
confusion_matrix_summary + return metrics, conf_mat_summaries + +def combine_fwd_bck_preds(target_dir, fn, comb_iou_threshold, comb_discard_threshold): + fwd_preds_fp = os.path.join(target_dir, f'peaks_pred_{fn}-fwd.txt') + bck_preds_fp = os.path.join(target_dir, f'peaks_pred_{fn}-bck.txt') + comb_preds_fp = os.path.join(target_dir, f'peaks_pred_{fn}-comb.txt') + match_preds_fp = os.path.join(target_dir, f'peaks_pred_{fn}-match.txt') + fwd_preds = pd.read_csv(fwd_preds_fp, sep='\t') + bck_preds = pd.read_csv(bck_preds_fp, sep='\t') + + c = Clip() + c.load_annotations(fwd_preds_fp) + c.load_predictions(bck_preds_fp) + c.compute_matching(IoU_minimum=comb_iou_threshold) + match_preds_list = [] + for fp, bp in c.matching: + match_pred = fwd_preds.loc[fp].copy() + bck_pred = bck_preds.iloc[bp] + bp_end_time = bck_pred['End Time (s)'] + match_pred['End Time (s)'] = bp_end_time + # Sorta like assuming forward and back predictions are independent, gives a high prob for the matched predictions + match_pred['Detection Prob'] = 1 - (1-match_pred['Detection Prob'])*(1-bck_pred['Detection Prob']) + match_preds_list.append(match_pred) + + match_preds = pd.DataFrame(match_preds_list, columns=fwd_preds.columns) + + # Include the union of all predictions that weren't part of the matching + fwd_matched_idxs = [m[0] for m in c.matching] + bck_matched_idxs = [m[1] for m in c.matching] + fwd_unmatched = select_from_neg_idxs(fwd_preds, fwd_matched_idxs) + bck_unmatched = select_from_neg_idxs(bck_preds, bck_matched_idxs) + to_concat = [x for x in [match_preds, fwd_unmatched, bck_unmatched] if x.shape[0]>0] + comb_preds = pd.concat(to_concat) if len(to_concat)>0 else fwd_preds + assert len(comb_preds) == len(fwd_preds) + len(bck_preds) - len(c.matching) + + # Finally, keep only predictions above a threshold, this will include almost all matches + comb_preds = comb_preds.loc[comb_preds['Detection Prob']>comb_discard_threshold] + comb_preds.sort_values('Begin Time (s)') + comb_preds.index = list(range(len(comb_preds))) + + comb_preds.to_csv(comb_preds_fp, sep='\t', index=False) + match_preds.to_csv(match_preds_fp, sep='\t', index=False) + return comb_preds_fp, match_preds_fp + +def select_from_neg_idxs(df, neg_idxs): + bool_mask = [i not in neg_idxs for i in range(len(df))] + return df.loc[bool_mask] diff --git a/voxaboxen/evaluation/raven_utils.py b/voxaboxen/evaluation/raven_utils.py index fe5f454..ccff380 100644 --- a/voxaboxen/evaluation/raven_utils.py +++ b/voxaboxen/evaluation/raven_utils.py @@ -20,60 +20,58 @@ def __init__(self, label_set = None, unknown_label = None): self.matched_predictions = None self.label_set = label_set self.unknown_label = unknown_label - + def load_selection_table(self, fp, view = None, label_mapping = None): # view (str) : If applicable, Waveform or Spectrogram to avoid double counting # label_mapping : dict {old label : new label}. If not None, will drop annotations not in keys of label_mapping - - + + annotations = pd.read_csv(fp, delimiter = '\t') if view is None and 'View' in annotations: views = annotations['View'].unique() if len(views)>1: warnings.warn(f"I found more than one view in selection table. To avoid double counting, pass view as a parameter. 
Views found: {view}") - + if view is not None: annotations = annotations[annotations['View'].str.contains('Waveform')].reset_index() - + if label_mapping is not None: annotations['Annotation'] = annotations['Annotation'].map(label_mapping) annotations = annotations[~pd.isnull(annotations['Annotation'])] - + return annotations - + def load_audio(self, fp): self.samples, self.sr = librosa.load(fp, sr = None) self.duration = len(self.samples) / self.sr - - def play_audio(self, start_sec, end_sec): start_sample = int(self.sr * start_sec) end_sample = int(self.sr *end_sec) display(ipd.Audio(self.samples[start_sample:end_sample], rate = self.sr)) - + def load_annotations(self, fp, view = None, label_mapping = None): self.annotations = self.load_selection_table(fp, view = view, label_mapping = label_mapping) self.annotations['index'] = self.annotations.index - + def threshold_class_predictions(self, class_threshold): # If class probability is below a threshold, switch label to unknown - + assert self.unknown_label is not None for i in self.predictions.index: if self.predictions.loc[i, 'Class Prob'] < class_threshold: - self.predictions.at[i, 'Annotation'] = self.unknown_label - + self.predictions.at[i, 'Annotation'] = self.unknown_label + def refine_annotations(self): print("Not implemented! Could implement refining annotations by SNR to remove quiet vocs") - + def refine_predictions(self): print("Not implemented! Could implement refining predictions by SNR to remove quiet vocs") - + def load_predictions(self, fp, view = None, label_mapping = None): self.predictions = self.load_selection_table(fp, view = view, label_mapping = label_mapping) self.predictions['index'] = self.predictions.index - + def compute_matching(self, IoU_minimum = 0.5): # Bipartite graph matching between predictions and annotations # Maximizes the number of matchings with IoU > IoU_minimum @@ -83,12 +81,11 @@ def compute_matching(self, IoU_minimum = 0.5): self.matching = metrics.match_events(ref, est, min_iou=IoU_minimum, method="fast") self.matched_annotations = [p[0] for p in self.matching] self.matched_predictions = [p[1] for p in self.matching] - + def evaluate(self): - eval_sr = 50 dur_samples = int(self.duration * eval_sr) # compute frame-wise metrics at 50Hz - + if self.label_set is None: TP = len(self.matching) FP = len(self.predictions) - TP @@ -113,7 +110,7 @@ def evaluate(self): FN_seg = int(((1-seg_predictions) * seg_annotations).sum()) return {'all' : {'TP' : TP, 'FP' : FP, 'FN' : FN, 'TP_seg' : TP_seg, 'FP_seg' : FP_seg, 'FN_seg' : FN_seg}} - + else: out = {label : {'TP':0, 'FP':0, 'FN' : 0, 'TP_seg':0, 'FP_seg':0, 'FN_seg':0} for label in self.label_set} pred_label = np.array(self.predictions['Annotation']) @@ -121,14 +118,14 @@ def evaluate(self): for p in self.matching: annotation = annot_label[p[0]] prediction = pred_label[p[1]] - + if self.unknown_label is not None and prediction == self.unknown_label: pass # treat predicted unknowns as no predictions for these metrics elif annotation == prediction: out[annotation]['TP'] += 1 elif self.unknown_label is not None and annotation == self.unknown_label: out[prediction]['FP'] -= 1 #adjust FP for unknown labels - + for label in self.label_set: n_annot = int((annot_label == label).sum()) n_pred = int((pred_label == label).sum()) @@ -160,7 +157,7 @@ def evaluate(self): out[label]['FN_seg'] = FN_seg return out - + def confusion_matrix(self): if self.label_set is None: return None @@ -173,10 +170,10 @@ def confusion_matrix(self): confusion_matrix = 
np.zeros((confusion_matrix_size, confusion_matrix_size)) cm_nobox_idx = confusion_matrix_labels.index('None') - + pred_label = np.array(self.predictions['Annotation']) annot_label = np.array(self.annotations['Annotation']) - + for p in self.matching: annotation = annot_label[p[0]] prediction = pred_label[p[1]] @@ -187,21 +184,21 @@ def confusion_matrix(self): for label in confusion_matrix_labels: if label == 'None': continue - + # count false positive and false negative detections, regardless of class cm_label_idx = confusion_matrix_labels.index(label) - + #fp n_pred = int((pred_label == label).sum()) n_positive_detections_row = confusion_matrix.sum(1)[cm_label_idx] n_false_detections = n_pred - n_positive_detections_row confusion_matrix[cm_label_idx, cm_nobox_idx] = n_false_detections - + #fn n_annot = int((annot_label == label).sum()) n_positive_detections_col = confusion_matrix.sum(0)[cm_label_idx] n_missed_detections = n_annot - n_positive_detections_col confusion_matrix[cm_nobox_idx, cm_label_idx] = n_missed_detections - + return confusion_matrix, confusion_matrix_labels - \ No newline at end of file + diff --git a/voxaboxen/inference/inference.py b/voxaboxen/inference/inference.py index d140180..cc4a22b 100644 --- a/voxaboxen/inference/inference.py +++ b/voxaboxen/inference/inference.py @@ -5,57 +5,65 @@ from voxaboxen.inference.params import parse_inference_args from voxaboxen.training.params import load_params -from voxaboxen.model.model import DetectionModel, DetectionModelStereo, DetectionModelMultichannel -from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table +from voxaboxen.model.model import DetectionModel#, DetectionModelStereo +from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table, combine_fwd_bck_preds from voxaboxen.data.data import get_single_clip_data device = "cuda" if torch.cuda.is_available() else "cpu" def inference(inference_args): inference_args = parse_inference_args(inference_args) - args = load_params(inference_args.model_args_fp) + args = load_params(inference_args.model_args_fp) files_to_infer = pd.read_csv(inference_args.file_info_for_inference) - + output_dir = os.path.join(args.experiment_dir, 'inference') if not os.path.exists(output_dir): os.makedirs(output_dir) + + model = DetectionModel(args) - # model - if hasattr(args,'stereo') and args.stereo: - model = DetectionModelStereo(args) - elif hasattr(args,'multichannel') and args.multichannel: - model = DetectionModelMultichannel(args) + if inference_args.model_checkpoint_fp is None: + model_checkpoint_fp = os.path.join(args.experiment_dir, "final-model.pt") else: - model = DetectionModel(args) - model_checkpoint_fp = os.path.join(args.experiment_dir, "model.pt") + model_checkpoint_fp = inference_args.model_checkpoint_fp + print(f"Loading model weights from {model_checkpoint_fp}") cp = torch.load(model_checkpoint_fp) - model.load_state_dict(cp["model_state_dict"]) + if "model_state_dict" in cp.keys(): + model.load_state_dict(cp["model_state_dict"]) + else: + model.load_state_dict(cp) model = model.to(device) for i, row in files_to_infer.iterrows(): audio_fp = row['audio_fp'] fn = row['fn'] - + if not os.path.exists(audio_fp): print(f"Could not locate file {audio_fp}") continue - + try: dataloader = get_single_clip_data(audio_fp, args.clip_duration/2, args) except: print(f"Could not load file {audio_fp}") continue - + if len(dataloader) == 0: print(f"Skipping {fn} because it is too short") continue - - detections, regressions, 
classifications = generate_predictions(model, dataloader, args, verbose = True) - - target_fp = export_to_selection_table(detections, regressions, classifications, fn, args, verbose=True, target_dir=output_dir, detection_threshold = inference_args.detection_threshold, classification_threshold = inference_args.classification_threshold) - - print(f"Saving predictions for {fn} to {target_fp}") + + if inference_args.disable_bidirectional and not model.is_bidirectional: + print('Warning: you have passed the disable-bidirectional arg but model is not is_bidirectional') + detections, regressions, classifs, rev_detections, rev_regressions, rev_classifs = generate_predictions(model, dataloader, args, verbose = True) + fwd_target_fp = export_to_selection_table(detections, regressions, classifs, fn, args, is_bck=False, verbose=True, target_dir=output_dir, detection_threshold=inference_args.detection_threshold, classification_threshold=inference_args.classification_threshold) + if model.is_bidirectional and not inference_args.disable_bidirectional: + rev_target_fp = export_to_selection_table(rev_detections, rev_regressions, rev_classifs, fn, args, is_bck=True, verbose=True, target_dir=output_dir, detection_threshold=inference_args.detection_threshold, classification_threshold=inference_args.classification_threshold) + comb_target_fp, match_target_fp = combine_fwd_bck_preds(args.experiment_output_dir, fn, comb_iou_threshold=args.comb_iou_threshold, comb_discard_threshold=model.comb_discard_thresh.item()) + print(f"Saving predictions for {fn} to {comb_target_fp}") + + else: + print(f"Saving predictions for {fn} to {fwd_target_fp}") if __name__ == "__main__": main(sys.argv[1:]) diff --git a/voxaboxen/inference/params.py b/voxaboxen/inference/params.py index e6c5274..ef99124 100644 --- a/voxaboxen/inference/params.py +++ b/voxaboxen/inference/params.py @@ -5,11 +5,13 @@ def parse_inference_args(inference_args): parser = argparse.ArgumentParser() - + parser.add_argument('--model-args-fp', type=str, required=True, help = "filepath of model params saved as a yaml") + parser.add_argument('--model-checkpoint-fp', type=str, default=None, help = "if passed, override default of final-model.pt") parser.add_argument('--file-info-for-inference', type=str, required=True, help = "filepath of info csv listing filenames and filepaths of audio for inference") parser.add_argument('--detection-threshold', type=float, default=0.5, help="detection peaks need to be at or above this threshold to make it into the exported selection table") - parser.add_argument('--classification-threshold', type=float, default=0.0, help="classification probability needs to be at or above this threshold to not be labeled as Unknown") - - inference_args = parser.parse_args(inference_args) + parser.add_argument('--classification-threshold', type=float, default=0.5, help="classification probability needs to be at or above this threshold to not be labeled as Unknown") + parser.add_argument('--disable-bidirectional', action='store_true') + + inference_args = parser.parse_args(inference_args) return inference_args diff --git a/voxaboxen/model/model.py b/voxaboxen/model/model.py index 9e3c4aa..8280bbe 100644 --- a/voxaboxen/model/model.py +++ b/voxaboxen/model/model.py @@ -32,7 +32,7 @@ def forward(self, sig): out = self.model.extract_features(sig)[0][-1] return out - + def freeze(self): for param in self.model.encoder.parameters(): param.requires_grad = False @@ -40,15 +40,22 @@ def freeze(self): def unfreeze(self): for param in 
self.model.encoder.parameters(): param.requires_grad = True - + class DetectionModel(nn.Module): def __init__(self, args, embedding_dim=768): super().__init__() + self.is_bidirectional = args.bidirectional if hasattr(args, "bidirectional") else False + self.is_stereo = args.stereo if hasattr(args, "stereo") else False + self.is_segmentation = args.segmentation_based if hasattr(args, "segmentation_based") else False + if self.is_stereo: + embedding_dim *= 2 self.encoder = AvesEmbedding(args) self.args = args aves_sr = args.sr // args.scale_factor self.detection_head = DetectionHead(args, embedding_dim = embedding_dim) - + if self.is_bidirectional: + self.rev_detection_head = DetectionHead(args, embedding_dim = embedding_dim) + def forward(self, x): """ Input @@ -59,22 +66,32 @@ def forward(self, x): class_logits (Tensor): (batch, time, n_classes) (time at 50 Hz, aves_sr) """ - - expected_dur_output = math.ceil(x.size(1)/self.args.scale_factor) - - x = x-torch.mean(x,axis=1,keepdim=True) - feats = self.encoder(x) - + + expected_dur_output = math.ceil(x.size(-1)/self.args.scale_factor) + + x = x-torch.mean(x,axis=-1,keepdim=True) + if self.is_stereo: + feats0 = self.encoder(x[:,0,:]) + feats1 = self.encoder(x[:,1,:]) + feats = torch.cat([feats0,feats1],dim=-1) + else: + feats = self.encoder(x) + #aves may be off by 1 sample from expected pad = expected_dur_output - feats.size(1) if pad>0: feats = F.pad(feats, (0,0,0,pad), mode='reflect') - + detection_logits, regression, class_logits = self.detection_head(feats) detection_probs = torch.sigmoid(detection_logits) - - return detection_probs, regression, class_logits - + if self.is_bidirectional: + rev_detection_logits, rev_regression, rev_class_logits = self.rev_detection_head(feats) + rev_detection_probs = torch.sigmoid(rev_detection_logits) + else: + rev_detection_probs = rev_regression = rev_class_logits = None + + return detection_probs, regression, class_logits, rev_detection_probs, rev_regression, rev_class_logits + def generate_features(self, x): """ Input @@ -82,22 +99,22 @@ def generate_features(self, x): Returns features (Tensor): (batch, time, embedding_dim) (time at 50 Hz, aves_sr) """ - + expected_dur_output = math.ceil(x.size(-1)/self.args.scale_factor) - + x = x-torch.mean(x,axis=-1,keepdim=True) feats = self.encoder(x) - + #aves may be off by 1 sample from expected pad = expected_dur_output - feats.size(1) if pad>0: feats = F.pad(feats, (0,0,0,pad), mode='reflect') - + return feats - + def freeze_encoder(self): self.encoder.freeze() - + def unfreeze_encoder(self): self.encoder.unfreeze() @@ -107,7 +124,7 @@ def __init__(self, args, embedding_dim=768): self.n_classes = len(args.label_set) self.head = nn.Conv1d(embedding_dim, 2+self.n_classes, args.prediction_scale_factor, stride=args.prediction_scale_factor, padding=0) self.args=args - + def forward(self, x): """ Input @@ -121,116 +138,42 @@ def forward(self, x): x = rearrange(x, 'b t c -> b c t') x = self.head(x) x = rearrange(x, 'b c t -> b t c') - detection_logits = x[:,:,0] + detection_logits = x[:,:,0] reg = x[:,:,1] class_logits = x[:,:,2:] return detection_logits, reg, class_logits - -class DetectionModelStereo(DetectionModel): - def __init__(self, args, embedding_dim=768): - super().__init__(args, embedding_dim=2*embedding_dim) - - def forward(self, x): - """ - Input - x (Tensor): (batch, channels, time) (time at 16000 Hz, audio_sr) - Returns - detection_probs (Tensor): (batch, time,) (time at 50 Hz, aves_sr) - regression (Tensor): (batch, time,) (time at 50 Hz, aves_sr) - 
class_logits (Tensor): (batch, time, n_classes) (time at 50 Hz, aves_sr) - - """ - - expected_dur_output = math.ceil(x.size(-1)/self.args.scale_factor) - - x = x-torch.mean(x,axis=-1,keepdim=True) - feats0 = self.encoder(x[:,0,:]) - feats1 = self.encoder(x[:,1,:]) - feats = torch.cat([feats0,feats1],dim=-1) - - #aves may be off by 1 sample from expected - pad = expected_dur_output - feats.size(1) - if pad>0: - feats = F.pad(feats, (0,0,0,pad), mode='reflect') - - detection_logits, regression, class_logits = self.detection_head(feats) - detection_probs = torch.sigmoid(detection_logits) - - return detection_probs, regression, class_logits - -class DetectionModelMultichannel(DetectionModel): - # supports >1 channel, but unlike Stereo model does not assume the order of channels matters. - def __init__(self, args, embedding_dim=768): - super().__init__(args, embedding_dim=embedding_dim) - self.n_classes = len(args.label_set) - - def forward(self, x): - """ - Input - x (Tensor): (batch, channels, time) (time at 16000 Hz, audio_sr) - Returns - detection_probs (Tensor): (batch, time,) (time at 50 Hz, aves_sr) - regression (Tensor): (batch, time,) (time at 50 Hz, aves_sr) - class_logits (Tensor): (batch, time, n_classes) (time at 50 Hz, aves_sr) - - """ - + +# class DetectionModelStereo(DetectionModel): +# def __init__(self, args, embedding_dim=768): +# super().__init__(args, embedding_dim=2*embedding_dim) + +# def forward(self, x): +# """ +# Input +# x (Tensor): (batch, channels, time) (time at 16000 Hz, audio_sr) +# Returns +# detection_probs (Tensor): (batch, time,) (time at 50 Hz, aves_sr) +# regression (Tensor): (batch, time,) (time at 50 Hz, aves_sr) +# class_logits (Tensor): (batch, time, n_classes) (time at 50 Hz, aves_sr) + +# """ + # expected_dur_output = math.ceil(x.size(-1)/self.args.scale_factor) - + # x = x-torch.mean(x,axis=-1,keepdim=True) -# feats=[] -# for i in range(x.size(1)): -# feats.append(self.encoder(x[:,i,:])) -# feats = sum(feats) +# feats0 = self.encoder(x[:,0,:]) +# feats1 = self.encoder(x[:,1,:]) +# feats = torch.cat([feats0,feats1],dim=-1) # #aves may be off by 1 sample from expected # pad = expected_dur_output - feats.size(1) # if pad>0: # feats = F.pad(feats, (0,0,0,pad), mode='reflect') - + # detection_logits, regression, class_logits = self.detection_head(feats) # detection_probs = torch.sigmoid(detection_logits) - - expected_dur_output = math.ceil(x.size(-1)/self.args.scale_factor) - - x = x-torch.mean(x,axis=-1,keepdim=True) - - detection_logits = [] - regression = [] - class_logits = [] - - for i in range(x.size(1)): - feats = self.encoder(x[:,i,:]) - #aves may be off by 1 sample from expected - pad = expected_dur_output - feats.size(1) - if pad>0: - feats = F.pad(feats, (0,0,0,pad), mode='reflect') - - dl, rg, cl = self.detection_head(feats) - detection_logits.append(dl) - regression.append(rg) - class_logits.append(cl) - - detection_logits = torch.stack(detection_logits, dim=0) - regression = torch.stack(regression, dim=0) - class_logits = torch.stack(class_logits, dim=0) - - if hasattr(self.args, "segmentation_based") and self.args.segmentation_based: - detection_logits=torch.max(detection_logits,dim=0)[0] - regression=torch.max(regression,dim=0)[0] - class_logits=torch.max(class_logits,dim=0)[0] - else: - mask_based_on_logits = torch.eq(torch.max(detection_logits, dim=0,keepdim=True)[0], detection_logits) - - detection_logits = torch.sum(detection_logits*mask_based_on_logits, dim=0) - regression = torch.sum(regression*mask_based_on_logits,dim=0) - 
class_logits = torch.sum(class_logits* mask_based_on_logits.unsqueeze(-1).repeat(1,1,1,self.n_classes), dim=0) - - detection_probs = torch.sigmoid(detection_logits) - - return detection_probs, regression, class_logits - +# return detection_probs, regression, class_logits def rms_and_mixup(X, d, r, y, train, args): if args.rms_norm: @@ -238,7 +181,7 @@ def rms_and_mixup(X, d, r, y, train, args): ms = ms + torch.full_like(ms, 1e-6) rms = ms ** (-1/2) X = X * rms - + if args.mixup and train: # TODO: For mixup, add in a check that there aren't extremely overlapping vocs @@ -256,13 +199,12 @@ def rms_and_mixup(X, d, r, y, train, args): r_aug = torch.flip(r, (0,)) * mask[:,:,0] y_aug = torch.flip(y, (0,)) * mask - X = (X + X_aug)#[:batch_size//2,...] - d = torch.maximum(d, d_aug)#[:batch_size//2,...] - r = torch.maximum(r, r_aug)#[:batch_size//2,...] - y = torch.maximum(y, y_aug)#[:batch_size//2,...] - + X = (X + X_aug) + d = torch.maximum(d, d_aug) + r = torch.maximum(r, r_aug) + y = torch.maximum(y, y_aug) if args.rms_norm: X = X * (1/2) - + return X, d, r, y - + diff --git a/voxaboxen/project/project_setup.py b/voxaboxen/project/project_setup.py index aa8b9ee..b06577c 100644 --- a/voxaboxen/project/project_setup.py +++ b/voxaboxen/project/project_setup.py @@ -7,37 +7,37 @@ def project_setup(args): args = parse_project_args(args) - + if not os.path.exists(args.project_dir): os.makedirs(args.project_dir) - + all_annots = [] for info_fp in [args.train_info_fp, args.val_info_fp, args.test_info_fp]: if info_fp is None: continue - + info = pd.read_csv(info_fp) annot_fps = list(info['selection_table_fp']) - + for annot_fp in annot_fps: if annot_fp != "None": selection_table = pd.read_csv(annot_fp, delimiter = '\t') annots = list(selection_table['Annotation'].astype(str)) all_annots.extend(annots) - + label_set = sorted(set(all_annots)) label_mapping = {x : x for x in label_set} label_mapping['Unknown'] = 'Unknown' unknown_label = 'Unknown' - + if unknown_label in label_set: label_set.remove(unknown_label) - + setattr(args, "label_set", label_set) setattr(args, "label_mapping", label_mapping) setattr(args, "unknown_label", unknown_label) - + save_params(args) if __name__ == "__main__": - project_setup(sys.argv[1:]) \ No newline at end of file + project_setup(sys.argv[1:]) diff --git a/voxaboxen/training/params.py b/voxaboxen/training/params.py index 806d63c..ab2b06c 100644 --- a/voxaboxen/training/params.py +++ b/voxaboxen/training/params.py @@ -8,10 +8,12 @@ def parse_args(args,allow_unknown=False): parser = argparse.ArgumentParser() - + # General parser.add_argument('--name', type = str, required=True) parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--is-test', '-t', action='store_true', help='run a quick version for testing') + parser.add_argument('--overwrite', action='store_true', help='overwrite an experiment of the same name, if it exists') # Data parser.add_argument('--project-config-fp', type = str, required=True) @@ -19,8 +21,9 @@ def parse_args(args,allow_unknown=False): parser.add_argument('--clip-hop', type=float, default=None, help = "clip hop, in seconds. If None, automatically set to be half clip duration. 
Used only during training; clip hop is automatically set to be 1/2 clip duration for inference") parser.add_argument('--train-info-fp', type=str, required=False, help = "train info, to override project train info") parser.add_argument('--num-workers', type=int, default=8) - + # Model + parser.add_argument('--bidirectional', action='store_true', help="train and inference in both directions and combine results") parser.add_argument('--sr', type=int, default=16000) parser.add_argument('--scale-factor', type=int, default = 320, help = "downscaling performed by aves") parser.add_argument('--aves-config-fp', type=str, default = "weights/birdaves-biox-base.torchaudio.model_config.json") @@ -32,13 +35,16 @@ def parse_args(args,allow_unknown=False): parser.add_argument('--stereo', action='store_true', help="If passed, will process stereo data as stereo. order of channels matters") parser.add_argument('--multichannel', action='store_true', help="If passed, will encode each audio channel seperately, then add together the encoded audio before final layer") parser.add_argument('--segmentation-based', action='store_true', help="If passed, will make predictions based on frame-wise segmentations rather than box starts") - + parser.add_argument('--comb-discard-thresh', type=float, default=0.75, help="If bidirectional, sets threshold for combining forward and backward predictions") + parser.add_argument('--comb-iou-threshold', type=float, default=0.5, help="minimum iou to match a forward and backward prediction") + # parser.add_argument('--reload-from', type=str) + # Training - parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--lr', type=float, default=.00005) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--lr', type=float, default=.00005) parser.add_argument('--n-epochs', type=int, default=50) parser.add_argument('--unfreeze-encoder-epoch', type=int, default=3) - parser.add_argument('--end-mask-perc', type=float, default = 0.1, help="During training, mask loss from a percentage of the frames on each end of the clip") + parser.add_argument('--end-mask-perc', type=float, default = 0.1, help="During training, mask loss from a percentage of the frames on each end of the clip") parser.add_argument('--omit-empty-clip-prob', type=float, default=0, help="if a clip has no annotations, do not use for training with this probability") parser.add_argument('--lamb', type=float, default=.04, help="parameter controlling strength regression loss") parser.add_argument('--rho', type=float, default = .01, help="parameter controlling strength of classification loss") @@ -48,13 +54,13 @@ def parse_args(args,allow_unknown=False): parser.add_argument('--early-stopping', action ="store_true", help="Whether to use early stopping based on val performance") parser.add_argument('--pos-loss-weight', type=float, default=1, help="Weights positive component of loss") - + # Augmentations - parser.add_argument('--amp-aug', action ="store_true", help="Whether to use amplitude augmentation") - parser.add_argument('--amp-aug-low-r', type=float, default = 0.8) - parser.add_argument('--amp-aug-high-r', type=float, default = 1.2) - parser.add_argument('--mixup', action ="store_true", help="Whether to use mixup augmentation") - + parser.add_argument('--amp-aug', action ="store_true", help="Whether to use amplitude augmentation") + parser.add_argument('--amp-aug-low-r', type=float, default = 0.8) + parser.add_argument('--amp-aug-high-r', type=float, default = 1.2) + 
parser.add_argument('--mixup', action ="store_true", help="Whether to use mixup augmentation") + # Inference parser.add_argument('--peak-distance', type=float, default=5, help="for finding peaks in detection probability, what radius to use for detecting local maxima. In output frame rate.") parser.add_argument('--nms', type = str, default='soft_nms', choices = ['none', 'nms', 'soft_nms'], help="Whether to apply additional nms after finding peaks") @@ -68,13 +74,13 @@ def parse_args(args,allow_unknown=False): args, remaining = parser.parse_known_args(args) else: args = parser.parse_args(args) - + args = read_config(args) check_config(args) if args.clip_hop is None: setattr(args, "clip_hop", args.clip_duration/2) - + if allow_unknown: return args, remaining else: @@ -83,10 +89,10 @@ def parse_args(args,allow_unknown=False): def read_config(args): with open(args.project_config_fp, 'r') as f: project_config = yaml.safe_load(f) - + for key in project_config: setattr(args,key,project_config[key]) - + return args def set_seed(seed): @@ -107,7 +113,7 @@ def save_params(args): with open(params_file, "w") as f: yaml.dump(args_dict, f) - + def load_params(fp): with open(fp, 'r') as f: args_dict = yaml.safe_load(f) @@ -125,4 +131,5 @@ def check_config(args): if args.segmentation_based and (args.rho!=1): import warnings warnings.warn("when using segmentation-based framework, recommend setting args.rho=1") - \ No newline at end of file + if args.bidirectional and args.segmentation_based: + raise ValueError("bidirectional and segmentation settings are not currently compatible") diff --git a/voxaboxen/training/train.py b/voxaboxen/training/train.py index 58abb97..852feb5 100644 --- a/voxaboxen/training/train.py +++ b/voxaboxen/training/train.py @@ -21,26 +21,19 @@ warnings.warn("Only using CPU! 
Check CUDA") def train(model, args): - model = model.to(device) - - if args.previous_checkpoint_fp is not None: - print(f"loading model weights from {args.previous_checkpoint_fp}") - cp = torch.load(args.previous_checkpoint_fp) - model.load_state_dict(cp["model_state_dict"]) - + detection_loss_fn = get_detection_loss_fn(args) reg_loss_fn = get_reg_loss_fn(args) - class_loss_fn = get_class_loss_fn(args) - + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad = True) # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step_size, gamma=0.1, last_epoch=- 1, verbose=False) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs, eta_min=0, last_epoch=- 1, verbose=False) - + train_evals = [] learning_rates = [] val_evals = [] - + if args.early_stopping: assert args.val_info_fp is not None best_f1 = 0 @@ -50,24 +43,25 @@ def train(model, args): use_val = True else: use_val = False - + for t in range(args.n_epochs): print(f"Epoch {t}\n-------------------------------") train_dataloader = get_train_dataloader(args, random_seed_shift = t) # reinitialize dataloader with different negatives each epoch model, train_eval = train_epoch(model, t, train_dataloader, detection_loss_fn, reg_loss_fn, class_loss_fn, optimizer, args) train_evals.append(train_eval.copy()) learning_rates.append(optimizer.param_groups[0]["lr"]) - + train_evals_by_epoch = {i : e for i, e in enumerate(train_evals)} train_evals_fp = os.path.join(args.experiment_dir, "train_history.yaml") with open(train_evals_fp, 'w') as f: yaml.dump(train_evals_by_epoch, f) - + if use_val: - val_eval = val_epoch(model, t, val_dataloader, detection_loss_fn, reg_loss_fn, class_loss_fn, args) - val_evals.append(val_eval.copy()) - plot_eval(train_evals, learning_rates, args, val_evals = val_evals) - + eval_scores = val_epoch(model, t, val_dataloader, args) + # TODO: maybe plot evals for other pred_types + val_evals.append(eval_scores['fwd'].copy()) + plot_eval(train_evals, learning_rates, args, val_evals=val_evals) + val_evals_by_epoch = {i : e for i, e in enumerate(val_evals)} val_evals_fp = os.path.join(args.experiment_dir, "val_history.yaml") with open(val_evals_fp, 'w') as f: @@ -75,9 +69,9 @@ def train(model, args): else: plot_eval(train_evals, learning_rates, args) scheduler.step() - + if use_val and args.early_stopping: - current_f1 = val_eval['f1'] + current_f1 = eval_scores['comb']['f1'] if model.is_bidirectional else eval_scores['fwd']['f1'] if current_f1 > best_f1: print('found new best model') best_f1 = current_f1 @@ -90,13 +84,13 @@ def train(model, args): "train_evals": train_evals, "val_evals" : val_evals } - + torch.save( checkpoint_dict, - os.path.join(args.experiment_dir, f"model.pt"), - ) - - else: + os.path.join(args.experiment_dir, "model.pt"), + ) + + else: checkpoint_dict = { "epoch": t, "model_state_dict": model.state_dict(), @@ -105,110 +99,143 @@ def train(model, args): "train_evals": train_evals, "val_evals" : val_evals } - + torch.save( checkpoint_dict, - os.path.join(args.experiment_dir, f"model.pt"), - ) - - + os.path.join(args.experiment_dir, "model.pt"), + ) + + print("Done!") - - cp = torch.load(os.path.join(args.experiment_dir, f"model.pt")) + + cp = torch.load(os.path.join(args.experiment_dir, "model.pt")) model.load_state_dict(cp["model_state_dict"]) - + # resave validation with best model if use_val: - val_epoch(model, t+1, val_dataloader, detection_loss_fn, reg_loss_fn, class_loss_fn, args) - - return model - + val_epoch(model, args.n_epochs, val_dataloader, args) + + 
return model + +def lf(dets, det_preds, regs, reg_preds, y, y_preds, args, det_loss_fn, reg_loss_fn, class_loss_fn): + # We mask out loss from each end of the clip, so the model isn't forced to learn to detect events that are partially cut off. + # This does not affect inference, because during inference we overlap clips at 50% + + end_mask_perc = args.end_mask_perc + end_mask_dur = int(det_preds.size(1)*end_mask_perc) + + det_preds_clipped = det_preds[:,end_mask_dur:-end_mask_dur] + dets_clipped = dets[:,end_mask_dur:-end_mask_dur] + + reg_preds_clipped = reg_preds[:,end_mask_dur:-end_mask_dur] + regs_clipped = regs[:,end_mask_dur:-end_mask_dur] + + y_clipped = y[:,end_mask_dur:-end_mask_dur,:] + + detection_loss = det_loss_fn(det_preds_clipped, dets_clipped, pos_loss_weight=args.pos_loss_weight) + reg_loss = reg_loss_fn(reg_preds_clipped, regs_clipped, dets_clipped, y_clipped) + if len(args.label_set)==1: + class_loss = torch.tensor(0) + else: + y_preds_clipped = y_preds[:,end_mask_dur:-end_mask_dur,:] + class_loss = class_loss_fn(y_preds_clipped, y_clipped, dets_clipped) + return detection_loss, reg_loss, class_loss + def train_epoch(model, t, dataloader, detection_loss_fn, reg_loss_fn, class_loss_fn, optimizer, args): model.train() if t < args.unfreeze_encoder_epoch: model.freeze_encoder() else: model.unfreeze_encoder() - - + + evals = {} train_loss = 0; losses = []; detection_losses = []; regression_losses = []; class_losses = [] + rev_train_loss = 0; rev_losses = []; rev_detection_losses = []; rev_regression_losses = []; rev_class_losses = [] + data_iterator = tqdm.tqdm(dataloader) - for i, (X, d, r, y) in enumerate(data_iterator): + #for i, (X, d, r, y, rev_d, rev_r, rev_y) in enumerate(data_iterator): + for i, batch in enumerate(data_iterator): num_batches_seen = i - X = X.to(device = device, dtype = torch.float) - d = d.to(device = device, dtype = torch.float) - r = r.to(device = device, dtype = torch.float) - y = y.to(device = device, dtype = torch.float) - + batch = [item.to(device, dtype=torch.float) for item in batch] + X, d, r, y = batch[:4] X, d, r, y = rms_and_mixup(X, d, r, y, True, args) - probs, regression, class_logits = model(X) - - # We mask out loss from each end of the clip, so the model isn't forced to learn to detect events that are partially cut off. 
- # This does not affect inference, because during inference we overlap clips at 50% - - end_mask_perc = args.end_mask_perc - end_mask_dur = int(probs.size(1)*end_mask_perc) - - d_clipped = d[:,end_mask_dur:-end_mask_dur] - probs_clipped = probs[:,end_mask_dur:-end_mask_dur] - - regression_clipped = regression[:,end_mask_dur:-end_mask_dur] - r_clipped = r[:,end_mask_dur:-end_mask_dur] - - class_logits_clipped = class_logits[:,end_mask_dur:-end_mask_dur,:] - y_clipped = y[:,end_mask_dur:-end_mask_dur,:] - - detection_loss = detection_loss_fn(probs_clipped, d_clipped, pos_loss_weight = args.pos_loss_weight) - reg_loss = reg_loss_fn(regression_clipped, r_clipped, d_clipped, y_clipped) - class_loss = class_loss_fn(class_logits_clipped, y_clipped, d_clipped) - + probs, regression, class_logits, rev_probs, rev_regression, rev_class_logits = model(X) + #model_outputs = model(X) + #probs, regression, class_logits = model_outputs[:3] + detection_loss, reg_loss, class_loss = lf(d, probs, r, regression, y, class_logits, args=args, det_loss_fn=detection_loss_fn, reg_loss_fn=reg_loss_fn, class_loss_fn=class_loss_fn) + loss = args.rho * class_loss + detection_loss + args.lamb * reg_loss train_loss += loss.item() losses.append(loss.item()) detection_losses.append(detection_loss.item()) regression_losses.append(args.lamb * reg_loss.item()) class_losses.append(args.rho * class_loss.item()) - - # Backpropagation + + pbar_str = f"loss {np.mean(losses[-10:]):.5f}, det {np.mean(detection_losses[-10:]):.5f}, reg {np.mean(regression_losses[-10:]):.5f}, class {np.mean(class_losses[-10:]):.5f}" + + if model.is_bidirectional: + assert all(x is not None for x in [rev_probs, rev_regression, rev_class_logits]) + rev_d, rev_r, rev_y = batch[4:] + #rev_probs, rev_regression, rev_class_logits = model_outputs[3:] + _, rev_d, rev_r, rev_y = rms_and_mixup(X, rev_d, rev_r, rev_y, True, args) + + rev_detection_loss, rev_reg_loss, rev_class_loss = lf(rev_d, rev_probs, rev_r, rev_regression, rev_y, rev_class_logits, args=args, det_loss_fn=detection_loss_fn, reg_loss_fn=reg_loss_fn, class_loss_fn=class_loss_fn) + rev_loss = args.rho * rev_class_loss + rev_detection_loss + args.lamb * rev_reg_loss + rev_train_loss += rev_loss.item() + rev_losses.append(rev_loss.item()) + rev_detection_losses.append(rev_detection_loss.item()) + rev_regression_losses.append(args.lamb * rev_reg_loss.item()) + rev_class_losses.append(args.rho * rev_class_loss.item()) + loss = (loss + rev_loss)/2 + + pbar_str += f" revloss {np.mean(rev_losses[-10:]):.5f}, revdet {np.mean(rev_detection_losses[-10:]):.5f}, revreg {np.mean(rev_regression_losses[-10:]):.5f}, revclass {np.mean(rev_class_losses[-10:]):.5f}" + else: + assert all(x is None for x in [rev_probs, rev_regression, rev_class_logits]) + + optimizer.zero_grad() loss.backward() - + optimizer.step() if i > 10: - data_iterator.set_description(f"Loss {np.mean(losses[-10:]):.7f}, Detection Loss {np.mean(detection_losses[-10:]):.7f}, Regression Loss {np.mean(regression_losses[-10:]):.7f}, Classification Loss {np.mean(class_losses[-10:]):.7f}") - + data_iterator.set_description(pbar_str) + + if args.is_test and i == 15: break + train_loss = train_loss / num_batches_seen evals['loss'] = float(train_loss) - + print(f"Epoch {t} | Train loss: {train_loss:1.3f}") return model, evals - -def val_epoch(model, t, dataloader, detection_loss_fn, reg_loss_fn, class_loss_fn, args): + +def val_epoch(model, t, dataloader, args): model.eval() - + manifest = predict_and_generate_manifest(model, dataloader, args, verbose = 
False) - e, _ = evaluate_based_on_manifest(manifest, args, output_dir = os.path.join(args.experiment_dir, 'val_results'), iou = args.model_selection_iou, class_threshold = args.model_selection_class_threshold) - - summary = e['summary'] - - evals = {k:[] for k in ['precision','recall','f1']} - for k in ['precision','recall','f1']: - for l in args.label_set: - m = summary[l][k] - evals[k].append(m) - evals[k] = float(np.mean(evals[k])) - - print(f"Epoch {t} | Val scores @{args.model_selection_iou}IoU: Precision: {evals['precision']:1.3f} Recall: {evals['recall']:1.3f} F1: {evals['f1']:1.3f}") + e, _ = evaluate_based_on_manifest(manifest, args, output_dir=os.path.join(args.experiment_dir, 'val_results'), iou=args.model_selection_iou, class_threshold=args.model_selection_class_threshold, comb_discard_threshold=args.comb_discard_threshold) + + print(f"Epoch {t} | val@{args.model_selection_iou}IoU:") + evals = {} + for pt in e.keys(): + evals[pt] = {k:[] for k in ['precision','recall','f1', 'precision_seg', 'recall_seg', 'f1_seg']} + for k in ['precision','recall','f1', 'precision_seg', 'recall_seg', 'f1_seg']: + for l in args.label_set: + m = e[pt]['summary'][l][k] + evals[pt][k].append(m) + evals[pt][k] = float(np.mean(evals[pt][k])) + + print(f"{pt}prec: {evals[pt]['precision']:1.3f} {pt}rec: {evals[pt]['recall']:1.3f} {pt}F1: {evals[pt]['f1']:1.3f} {pt}prec_seg: {evals[pt]['precision_seg']:1.3f} {pt}rec_seg: {evals[pt]['recall_seg']:1.3f} {pt}F1_seg: {evals[pt]['f1_seg']:1.3f}", end=' ') + print() return evals def modified_focal_loss(pred, gt, pos_loss_weight = 1): # Modified from https://github.com/xingyizhou/CenterNet/blob/2b7692c377c6686fb35e473dac2de6105eed62c6/src/lib/models/losses.py - ''' + ''' pred [batch, time,] gt [batch, time,] - ''' - + ''' + pos_inds = gt.eq(1).float() neg_inds = gt.lt(1).float() @@ -218,48 +245,48 @@ def modified_focal_loss(pred, gt, pos_loss_weight = 1): pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds * pos_loss_weight neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds - + loss = -1.*(neg_loss + pos_loss) - + loss = loss.mean() return loss - - + + def masked_reg_loss(regression, r, d, y, class_weights = None): # regression, r (Tensor): [batch, time,] # r (Tensor) : [batch, time,], float tensor # d (Tensor) : [batch, time,], float tensor # y (Tensor) : [batch, time, n_classes] # class_weights (Tensor) : [n_classes,] - + reg_loss = F.l1_loss(regression, r, reduction='none') mask = d.eq(1).float() - + reg_loss = reg_loss * mask - + if class_weights is not None: y = rearrange(y, 'b t c -> b c t') high_prob = torch.amax(y, dim = 1) knowns = high_prob.eq(1).float() unknowns = high_prob.lt(1).float() - + reg_loss_unknowns = reg_loss * unknowns - + class_weights = torch.reshape(class_weights, (1, -1, 1)) class_weights = y * class_weights class_weights = torch.amax(class_weights, dim = 1) - + reg_loss_knowns = reg_loss * knowns * class_weights - + reg_loss = reg_loss_unknowns + reg_loss_knowns - + reg_loss = torch.sum(reg_loss) n_pos = mask.sum() - + if n_pos>0: reg_loss = reg_loss / n_pos - + return reg_loss def masked_classification_loss(class_logits, y, d, class_weights = None): @@ -267,30 +294,30 @@ def masked_classification_loss(class_logits, y, d, class_weights = None): # y (Tensor): [batch, time,n_classes] # d (Tensor) : [batch, time,], float tensor # class_weight : [n_classes,], float tensor - + class_logits = rearrange(class_logits, 'b t c -> b c t') y = rearrange(y, 'b t c -> b c t') - + high_prob = torch.amax(y, dim 
= 1) knowns = high_prob.eq(1).float() unknowns = high_prob.lt(1).float() - + mask = d.eq(1).float() # mask out time steps where no event is present - + known_class_loss = F.cross_entropy(class_logits, y, weight=class_weights, reduction='none') known_class_loss = known_class_loss * mask * knowns known_class_loss = torch.sum(known_class_loss) - + unknown_class_loss = F.cross_entropy(class_logits, y, weight=None, reduction='none') unknown_class_loss = unknown_class_loss * mask * unknowns unknown_class_loss = torch.sum(unknown_class_loss) - + class_loss = known_class_loss + unknown_class_loss n_pos = mask.sum() - + if n_pos>0: class_loss = class_loss / n_pos - + return class_loss def segmentation_loss(class_logits, y, d, class_weights=None): @@ -302,7 +329,6 @@ def segmentation_loss(class_logits, y, d, class_weights=None): default_focal_loss = torchvision.ops.sigmoid_focal_loss(class_logits, y, reduction='mean') return default_focal_loss - def get_class_loss_fn(args): if hasattr(args,"segmentation_based") and args.segmentation_based: return segmentation_loss @@ -310,10 +336,10 @@ def get_class_loss_fn(args): dataloader_temp = get_train_dataloader(args, random_seed_shift = 0) class_proportions = dataloader_temp.dataset.get_class_proportions() class_weights = 1. / (class_proportions + 1e-6) - + class_weights = class_weights * (class_proportions>0) # ignore weights for unrepresented classes + class_weights = (1. / (np.mean(class_weights) + 1e-6)) * class_weights # normalize so average weight = 1 - - print(f"Using class weights {class_weights}") + print(f"Using class weights {class_weights}") class_weights = torch.Tensor(class_weights).to(device) return partial(masked_classification_loss, class_weights = class_weights) @@ -327,7 +353,8 @@ def zrl(regression, r, d, y, class_weights = None): dataloader_temp = get_train_dataloader(args, random_seed_shift = 0) class_proportions = dataloader_temp.dataset.get_class_proportions() class_weights = 1. / (class_proportions + 1e-6) - + class_weights = class_weights * (class_proportions>0) # ignore weights for unrepresented classes + class_weights = (1. 
/ (np.mean(class_weights) + 1e-6)) * class_weights # normalize so average weight = 1
         class_weights = torch.Tensor(class_weights).to(device)
@@ -340,4 +367,3 @@ def zdl(pred, gt, pos_loss_weight = 1):
         return zdl
     else:
         return modified_focal_loss
-
\ No newline at end of file
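
The class-weighting step above (used both in `get_class_loss_fn` and in the regression-loss counterpart) inverts the observed class proportions, zeroes out weights for classes that never appear, and renormalizes so the average weight is 1. A minimal standalone sketch of that arithmetic, using made-up proportions that are not taken from this patch:

```python
import numpy as np

# Hypothetical class proportions; the third class is unrepresented in the training data.
class_proportions = np.array([0.70, 0.25, 0.0, 0.05])

class_weights = 1. / (class_proportions + 1e-6)                      # rare classes get large weights
class_weights = class_weights * (class_proportions > 0)              # ignore weights for unrepresented classes
class_weights = (1. / (np.mean(class_weights) + 1e-6)) * class_weights  # normalize so average weight = 1

print(class_weights.round(3))  # e.g. [0.225 0.629 0.    3.146], whose mean is ~1
```
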
diff --git a/voxaboxen/training/train_model.py b/voxaboxen/training/train_model.py
index b59b128..480c6ad 100644
--- a/voxaboxen/training/train_model.py
+++ b/voxaboxen/training/train_model.py
@@ -1,50 +1,61 @@
+import pandas as pd
+import torch
 from voxaboxen.data.data import get_test_dataloader
-from voxaboxen.model.model import DetectionModel, DetectionModelStereo, DetectionModelMultichannel
+from voxaboxen.model.model import DetectionModel
 from voxaboxen.training.train import train
 from voxaboxen.training.params import parse_args, set_seed, save_params
-from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table, get_metrics, summarize_metrics, predict_and_generate_manifest, evaluate_based_on_manifest
-
-import yaml
+from voxaboxen.evaluation.evaluation import predict_and_generate_manifest, evaluate_based_on_manifest
 import sys
 import os

+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def print_metrics(metrics, just_one_label):
+    for pred_type in metrics.keys():
+        to_print = {k1:{k:round(100*v,4) for k,v in v1.items()} for k1,v1 in metrics[pred_type]['summary'].items()} if just_one_label else dict(pd.DataFrame(metrics[pred_type]['summary']).mean(axis=1).round(4))
+        print(f'{pred_type}:', to_print)
+
 def train_model(args):
     ## Setup
     args = parse_args(args)
     set_seed(args.seed)
-
+
     experiment_dir = os.path.join(args.project_dir, args.name)
     setattr(args, 'experiment_dir', str(experiment_dir))
-    if not os.path.exists(args.experiment_dir):
-        os.makedirs(args.experiment_dir)
-
+    if os.path.exists(args.experiment_dir) and (not args.overwrite) and args.name!='demo':
+        sys.exit('experiment already exists with this name')
+
     experiment_output_dir = os.path.join(experiment_dir, "outputs")
     setattr(args, 'experiment_output_dir', experiment_output_dir)
     if not os.path.exists(args.experiment_output_dir):
         os.makedirs(args.experiment_output_dir)
-
+
     save_params(args)
-    if hasattr(args,'stereo') and args.stereo:
-        model = DetectionModelStereo(args)
-    elif hasattr(args,'multichannel') and args.multichannel:
-        model = DetectionModelMultichannel(args)
-    else:
-        model = DetectionModel(args)
-
+    model = DetectionModel(args).to(device)
+
+    if args.previous_checkpoint_fp is not None:
+        print(f"loading model weights from {args.previous_checkpoint_fp}")
+        cp = torch.load(args.previous_checkpoint_fp)
+        if "model_state_dict" in cp.keys():
+            model.load_state_dict(cp["model_state_dict"])
+        else:
+            model.load_state_dict(cp)
+
     ## Training
-    trained_model = train(model, args)
+    if args.n_epochs>0:
+        model = train(model, args)

-    ## Evaluation
     test_dataloader = get_test_dataloader(args)
-
-    manifest = predict_and_generate_manifest(trained_model, test_dataloader, args)
-
+    test_manifest = predict_and_generate_manifest(model, test_dataloader, args)
     for iou in [0.2, 0.5, 0.8]:
-        for class_threshold in [0.0, 0.5, 0.95]:
-            evaluate_based_on_manifest(manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results') , iou = iou, class_threshold = class_threshold)
+        test_metrics, test_conf_mats = evaluate_based_on_manifest(test_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results') , iou=iou, class_threshold=0.5, comb_discard_threshold=args.comb_discard_thresh)
+        print(f'Test with IOU{iou}')
+        print_metrics(test_metrics, just_one_label=(len(args.label_set)==1))
+
+    torch.save(model.state_dict(), os.path.join(args.experiment_dir, 'final-model.pt'))

 if __name__ == "__main__":
     train_model(sys.argv[1:])
-
+# python main.py --name=debug --lr=0.0001 --n-epochs=6 --clip-duration=4 --batch-size=100 --omit-empty-clip-prob=0.5 --clip-hop=2
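
The new `--bidirectional` mode trains a second head that places boxes at event offsets, and `--comb-iou-threshold` / `--comb-discard-thresh` control how forward and backward predictions are merged at inference. As a rough illustration of the idea only (the actual combination logic lives in voxaboxen/evaluation/evaluation.py, which is not part of this patch, and may differ), forward and backward boxes can be matched greedily by IoU, merged into a single box, and unmatched low-confidence boxes discarded:

```python
# Illustrative sketch, not the repo's implementation. Boxes are (start, end) in seconds.
import numpy as np

def iou_1d(a, b):
    # Intersection-over-union of two 1-D intervals.
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0

def combine_fwd_bck(fwd_boxes, bck_boxes, fwd_probs, bck_probs,
                    iou_thresh=0.5, discard_thresh=0.75):
    """Greedily match each forward box to its best-overlapping backward box.
    Matched pairs are merged (onset from the forward box, offset from the backward box);
    unmatched boxes are kept only if their detection probability exceeds discard_thresh."""
    combined, used_bck = [], set()
    for i, fb in enumerate(fwd_boxes):
        ious = [0.0 if j in used_bck else iou_1d(fb, bb) for j, bb in enumerate(bck_boxes)]
        j = int(np.argmax(ious)) if ious else -1
        if j >= 0 and ious[j] >= iou_thresh:
            used_bck.add(j)
            combined.append((fb[0], bck_boxes[j][1], 0.5 * (fwd_probs[i] + bck_probs[j])))
        elif fwd_probs[i] > discard_thresh:
            combined.append((fb[0], fb[1], fwd_probs[i]))
    for j, bb in enumerate(bck_boxes):
        if j not in used_bck and bck_probs[j] > discard_thresh:
            combined.append((bb[0], bb[1], bck_probs[j]))
    return combined

# Toy example: one event roughly spanning 1.0 s to 2.0 s.
print(combine_fwd_bck([(1.00, 1.95)], [(1.05, 2.00)], [0.9], [0.8]))
# -> [(1.0, 2.0, 0.85...)]
```

The payoff of this kind of merging is that a forward prediction that localizes the onset well and a backward prediction that localizes the offset well can yield a tighter combined box than either direction alone, which is why the README describes the bidirectional flag as potentially improving box regression.
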