diff --git a/voxaboxen/evaluation/evaluation.py b/voxaboxen/evaluation/evaluation.py
index b6d5362..8ddbb9f 100644
--- a/voxaboxen/evaluation/evaluation.py
+++ b/voxaboxen/evaluation/evaluation.py
@@ -227,7 +227,7 @@ def generate_features(model, single_clip_dataloader, args, verbose = True):
 
     return all_features.detach().cpu().numpy()
 
-def export_to_selection_table(dets, regs, classifs, fn, args, is_bck, verbose=True, target_dir=None, classif_threshold=0):
+def export_to_selection_table(dets, regs, classifs, fn, args, is_bck, verbose=True, target_dir=None, detection_threshold=0, classif_threshold=0):
 
     if target_dir is None:
         target_dir = args.experiment_output_dir
@@ -248,7 +248,7 @@ def export_to_selection_table(dets, regs, classifs, fn, args, is_bck, verbose=Tr
     #     np.save(target_fp, classifs)
 
     ## peaks
-    det_peaks, properties = find_peaks(dets, height=args.detection_threshold, distance=args.peak_distance)
+    det_peaks, properties = find_peaks(dets, height=detection_threshold, distance=args.peak_distance)
     det_probs = properties['peak_heights']
 
     ## regs and classifs
@@ -278,7 +278,7 @@ def export_to_selection_table(dets, regs, classifs, fn, args, is_bck, verbose=Tr
     bboxes, det_probs, class_idxs, class_probs = pred2bbox(det_peaks, det_probs, durations, class_idxs, class_probs, pred_sr, is_bck)
 
     if args.nms == "soft_nms":
-        bboxes, det_probs, class_idxs, class_probs = soft_nms(bboxes, det_probs, class_idxs, class_probs, sigma=args.soft_nms_sigma, thresh=args.detection_threshold)
+        bboxes, det_probs, class_idxs, class_probs = soft_nms(bboxes, det_probs, class_idxs, class_probs, sigma=args.soft_nms_sigma, thresh=detection_threshold)
     elif args.nms == "nms":
         bboxes, det_probs, class_idxs, class_probs = nms(bboxes, det_probs, class_idxs, class_probs, iou_thresh=args.nms_thresh)
 
@@ -423,10 +423,10 @@ def predict_and_generate_manifest(model, dataloader_dict, args, verbose = True):
     for fn in dataloader_dict:
         fwd_detections, fwd_regressions, fwd_classifications, bck_detections, bck_regressions, bck_classifications = generate_predictions(model, dataloader_dict[fn], args, verbose=verbose)
 
-        fwd_predictions_fp = export_to_selection_table(fwd_detections, fwd_regressions, fwd_classifications, fn, args, is_bck=False, verbose=verbose)
+        fwd_predictions_fp = export_to_selection_table(fwd_detections, fwd_regressions, fwd_classifications, fn, args, is_bck=False, verbose=verbose, detection_threshold=args.detection_threshold)
         if model.is_bidirectional:
             assert all(x is not None for x in [bck_detections, bck_classifications, bck_regressions])
-            bck_predictions_fp = export_to_selection_table(bck_detections, bck_regressions, bck_classifications, fn, args, is_bck=True, verbose=verbose)
+            bck_predictions_fp = export_to_selection_table(bck_detections, bck_regressions, bck_classifications, fn, args, is_bck=True, verbose=verbose, detection_threshold=args.detection_threshold)
         else:
             assert all(x is None for x in [bck_detections, bck_classifications, bck_regressions])
             bck_predictions_fp = None
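Note: export_to_selection_table now takes detection_threshold as an explicit argument (default 0) instead of reading args.detection_threshold, so callers such as inference can apply a different threshold than the one stored in the training args. A minimal sketch of what the threshold does, using the same scipy.signal.find_peaks call as the function above (the example values are illustrative, not from the patch):

    import numpy as np
    from scipy.signal import find_peaks

    dets = np.array([0.1, 0.7, 0.2, 0.55, 0.1, 0.9, 0.3])  # per-frame detection probabilities
    for thresh in (0.5, 0.8):
        det_peaks, properties = find_peaks(dets, height=thresh, distance=1)
        print(thresh, det_peaks, properties['peak_heights'])
    # height=0.5 keeps local maxima at frames 1, 3, 5; height=0.8 keeps only frame 5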
diff --git a/voxaboxen/inference/inference.py b/voxaboxen/inference/inference.py
index f223587..adfe9ce 100644
--- a/voxaboxen/inference/inference.py
+++ b/voxaboxen/inference/inference.py
@@ -6,54 +6,63 @@
 from voxaboxen.inference.params import parse_inference_args
 from voxaboxen.training.params import load_params
 from voxaboxen.model.model import DetectionModel, DetectionModelStereo
-from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table
+from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table, combine_fwd_bck_preds
 from voxaboxen.data.data import get_single_clip_data
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def inference(inference_args):
     inference_args = parse_inference_args(inference_args)
-    args = load_params(inference_args.model_args_fp)
+    args = load_params(inference_args.model_args_fp)
     files_to_infer = pd.read_csv(inference_args.file_info_for_inference)
-
+
     output_dir = os.path.join(args.experiment_dir, 'inference')
     if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
-
-    # model
-    if hasattr(args,'stereo') and args.stereo:
-        model = DetectionModelStereo(args)
-    else:
-        model = DetectionModel(args)
-    model_checkpoint_fp = os.path.join(args.experiment_dir, "model.pt")
+        os.makedirs(output_dir)
+
+    # model
+    #if hasattr(args,'stereo') and args.stereo:
+        #model = DetectionModelStereo(args)
+    #else:
+    model = DetectionModel(args)
+    model_checkpoint_fp = os.path.join(args.experiment_dir, "final-model.pt")
     print(f"Loading model weights from {model_checkpoint_fp}")
     cp = torch.load(model_checkpoint_fp)
-    model.load_state_dict(cp["model_state_dict"])
+    model.load_state_dict(cp)
     model = model.to(device)
-
+
     for i, row in files_to_infer.iterrows():
         audio_fp = row['audio_fp']
         fn = row['fn']
-
+
         if not os.path.exists(audio_fp):
             print(f"Could not locate file {audio_fp}")
             continue
-
+
         try:
             dataloader = get_single_clip_data(audio_fp, args.clip_duration/2, args)
         except:
             print(f"Could not load file {audio_fp}")
             continue
-
+
         if len(dataloader) == 0:
             print(f"Skipping {fn} because it is too short")
             continue
-
-        detections, regressions, classifications = generate_predictions(model, dataloader, args, verbose = True)
-
-        target_fp = export_to_selection_table(detections, regressions, classifications, fn, args, verbose=True, target_dir=output_dir, detection_threshold = inference_args.detection_threshold, classification_threshold = inference_args.classification_threshold)
-
-        print(f"Saving predictions for {fn} to {target_fp}")
+
+        if inference_args.disable_bidirectional and not model.is_bidirectional:
+            print('Warning: --disable-bidirectional was passed, but the model is not bidirectional')
+        detections, regressions, classifs, rev_detections, rev_regressions, rev_classifs = generate_predictions(model, dataloader, args, verbose = True)
+        fwd_target_fp = export_to_selection_table(detections, regressions, classifs, fn, args, is_bck=False, verbose=True, target_dir=output_dir, detection_threshold=inference_args.detection_threshold, classif_threshold=inference_args.classification_threshold)
+        if model.is_bidirectional and not inference_args.disable_bidirectional:
+            rev_target_fp = export_to_selection_table(rev_detections, rev_regressions, rev_classifs, fn, args, is_bck=True, verbose=True, target_dir=output_dir, detection_threshold=inference_args.detection_threshold, classif_threshold=inference_args.classification_threshold)
+            comb_target_fp, match_target_fp = combine_fwd_bck_preds(output_dir, fn, comb_iou_threshold=args.comb_iou_threshold, comb_discard_threshold=model.comb_discard_thresh.item())
+            print(f"Saving predictions for {fn} to {comb_target_fp}")
+        else:
+            print(f"Saving predictions for {fn} to {fwd_target_fp}")
 
 if __name__ == "__main__":
     main(sys.argv[1:])
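Note: the rewritten loop reads only the 'fn' and 'audio_fp' columns of the CSV passed as --file-info-for-inference, and writes selection tables into <experiment_dir>/inference. A minimal sketch of building such a CSV (the paths are hypothetical):

    import pandas as pd

    pd.DataFrame({
        'fn': ['soundscape_01'],                        # stem used to name the output selection table
        'audio_fp': ['/data/audio/soundscape_01.wav'],  # audio file to run detection on
    }).to_csv('file_info_for_inference.csv', index=False)

When the model is bidirectional and --disable-bidirectional is not passed, the forward and backward tables are merged by combine_fwd_bck_preds; otherwise only the forward table is written.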
diff --git a/voxaboxen/inference/params.py b/voxaboxen/inference/params.py
index e6c5274..b54b3c2 100644
--- a/voxaboxen/inference/params.py
+++ b/voxaboxen/inference/params.py
@@ -5,11 +5,12 @@ def parse_inference_args(inference_args):
     parser = argparse.ArgumentParser()
-
+
     parser.add_argument('--model-args-fp', type=str, required=True, help = "filepath of model params saved as a yaml")
     parser.add_argument('--file-info-for-inference', type=str, required=True, help = "filepath of info csv listing filenames and filepaths of audio for inference")
     parser.add_argument('--detection-threshold', type=float, default=0.5, help="detection peaks need to be at or above this threshold to make it into the exported selection table")
     parser.add_argument('--classification-threshold', type=float, default=0.0, help="classification probability needs to be at or above this threshold to not be labeled as Unknown")
-
-    inference_args = parser.parse_args(inference_args)
+    parser.add_argument('--disable-bidirectional', action='store_true', help="only export forward predictions, even if the model is bidirectional")
+
+    inference_args = parser.parse_args(inference_args)
     return inference_args
diff --git a/voxaboxen/model/model.py b/voxaboxen/model/model.py
index 5f8a6a8..6d472de 100644
--- a/voxaboxen/model/model.py
+++ b/voxaboxen/model/model.py
@@ -53,6 +53,7 @@ def __init__(self, args, embedding_dim=768):
         self.args = args
         aves_sr = args.sr // args.scale_factor
         self.detection_head = DetectionHead(args, embedding_dim = embedding_dim)
+        self.comb_discard_thresh = nn.Parameter(torch.tensor(0.))
         if self.is_bidirectional:
             self.rev_detection_head = DetectionHead(args, embedding_dim = embedding_dim)
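Note: storing comb_discard_thresh as an nn.Parameter means the tuned threshold travels inside state_dict(), which is what lets train_model.py persist it in final-model.pt and inference.py recover it via model.comb_discard_thresh.item() after load_state_dict. A self-contained sketch of the round trip (a toy module, not the real DetectionModel):

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.comb_discard_thresh = nn.Parameter(torch.tensor(0.))  # scalar stored with the weights

    src = Toy()
    src.comb_discard_thresh = nn.Parameter(torch.tensor(0.5))  # mirrors the assignment in train_model.py
    dst = Toy()
    dst.load_state_dict(src.state_dict())
    print(dst.comb_discard_thresh.item())  # 0.5

A register_buffer would serialize the same way without exposing the value to optimizers; nn.Parameter is presumably used here to keep the plain-attribute assignment style shown in train_model.py.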
diff --git a/voxaboxen/training/train_model.py b/voxaboxen/training/train_model.py
index 7db1315..2f68b92 100644
--- a/voxaboxen/training/train_model.py
+++ b/voxaboxen/training/train_model.py
@@ -1,12 +1,12 @@
+import torch.nn as nn
 import pandas as pd
-from voxaboxen.data.data import get_test_dataloader, get_val_dataloader
 import torch
-from voxaboxen.model.model import DetectionModel, DetectionModelStereo
+from voxaboxen.data.data import get_test_dataloader, get_val_dataloader
+from voxaboxen.model.model import DetectionModel
 from voxaboxen.training.train import train
 from voxaboxen.training.params import parse_args, set_seed, save_params
-from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table, get_metrics, summarize_metrics, predict_and_generate_manifest, evaluate_based_on_manifest
+from voxaboxen.evaluation.evaluation import predict_and_generate_manifest, evaluate_based_on_manifest
-import yaml
 import sys
 import os
@@ -36,9 +36,9 @@ def train_model(args):
     save_params(args)
     model = DetectionModel(args)
 
-    if args.reload_from is not None:
-        checkpoint = torch.load(os.path.join(args.project_dir, args.reload_from, 'model.pt'))
-        model.load_state_dict(checkpoint['model_state_dict'])
+    #if args.reload_from is not None:
+        #checkpoint = torch.load(os.path.join(args.project_dir, args.reload_from, 'model.pt'))
+        #model.load_state_dict(checkpoint['model_state_dict'])
 
     ## Training
     trained_model = train(model, args)
@@ -49,14 +49,13 @@ def train_model(args):
 
     val_manifest = predict_and_generate_manifest(trained_model, val_dataloader, args)
 
-    model.comb_discard_thresh = -1
     if model.is_bidirectional:
         best_f1 = 0
         for comb_discard_thresh in [.3,.35,.4,.45,.5,.55,.6,.65,.75,.8,.85,.9]:
             val_metrics, val_conf_mats = evaluate_based_on_manifest(val_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results') , iou=0.5, class_threshold=0.5, comb_discard_threshold=comb_discard_thresh)
             new_f1 = val_metrics['comb']['macro']['f1']
             if new_f1 > best_f1:
-                model.comb_discard_thresh = comb_discard_thresh
+                model.comb_discard_thresh = nn.Parameter(torch.tensor(comb_discard_thresh))
                 best_f1 = new_f1
             print(f'IOU: 0.5 class_thresh: 0.5 Comb discard threshold: {comb_discard_thresh}')
             print_metrics(val_metrics, just_one_label=(len(args.label_set)==1))
@@ -64,10 +63,12 @@ def train_model(args):
 
     test_manifest = predict_and_generate_manifest(trained_model, test_dataloader, args)
     for iou in [0.2, 0.5, 0.8]:
-        test_metrics, test_conf_mats = evaluate_based_on_manifest(test_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results') , iou=iou, class_threshold=0.5, comb_discard_threshold=model.comb_discard_thresh)
+        test_metrics, test_conf_mats = evaluate_based_on_manifest(test_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results') , iou=iou, class_threshold=0.5, comb_discard_threshold=model.comb_discard_thresh.item())
         print(f'Test with IOU{iou}')
         print_metrics(test_metrics, just_one_label=(len(args.label_set)==1))
 
+    torch.save(model.state_dict(), os.path.join(args.experiment_dir, 'final-model.pt'))
+
 if __name__ == "__main__":
     train_model(sys.argv[1:])
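Note: training now ends by saving a bare state_dict to final-model.pt, which is why inference.py above switched from model.load_state_dict(cp["model_state_dict"]) to model.load_state_dict(cp). A minimal sketch of the new save/load pairing (nn.Linear stands in for DetectionModel):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    torch.save(model.state_dict(), 'final-model.pt')  # train_model.py side: raw state_dict, no wrapper dict
    restored = nn.Linear(4, 2)
    cp = torch.load('final-model.pt')                 # inference.py side: cp is already the state_dict
    restored.load_state_dict(cp)

Because comb_discard_thresh is an nn.Parameter (see model.py above), the threshold selected by the validation sweep is carried along in this checkpoint automatically.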