handle loading of maybe-bidirectional model at inference
Lou1sM committed May 18, 2024
1 parent 7312733 commit cfcfe50
Showing 5 changed files with 52 additions and 40 deletions.
10 changes: 5 additions & 5 deletions voxaboxen/evaluation/evaluation.py
@@ -227,7 +227,7 @@ def generate_features(model, single_clip_dataloader, args, verbose = True):

return all_features.detach().cpu().numpy()

def export_to_selection_table(dets, regs, classifs, fn, args, is_bck, verbose=True, target_dir=None, classif_threshold=0):
def export_to_selection_table(dets, regs, classifs, fn, args, is_bck, verbose=True, target_dir=None, detection_threshold=0, classif_threshold=0):

if target_dir is None:
target_dir = args.experiment_output_dir
@@ -248,7 +248,7 @@ def export_to_selection_table(dets, regs, classifs, fn, args, is_bck, verbose=Tr
# np.save(target_fp, classifs)

## peaks
det_peaks, properties = find_peaks(dets, height=args.detection_threshold, distance=args.peak_distance)
det_peaks, properties = find_peaks(dets, height=detection_threshold, distance=args.peak_distance)
det_probs = properties['peak_heights']

## regs and classifs
@@ -278,7 +278,7 @@ def export_to_selection_table(dets, regs, classifs, fn, args, is_bck, verbose=Tr
bboxes, det_probs, class_idxs, class_probs = pred2bbox(det_peaks, det_probs, durations, class_idxs, class_probs, pred_sr, is_bck)

if args.nms == "soft_nms":
bboxes, det_probs, class_idxs, class_probs = soft_nms(bboxes, det_probs, class_idxs, class_probs, sigma=args.soft_nms_sigma, thresh=args.detection_threshold)
bboxes, det_probs, class_idxs, class_probs = soft_nms(bboxes, det_probs, class_idxs, class_probs, sigma=args.soft_nms_sigma, thresh=detection_threshold)
elif args.nms == "nms":
bboxes, det_probs, class_idxs, class_probs = nms(bboxes, det_probs, class_idxs, class_probs, iou_thresh=args.nms_thresh)

@@ -423,10 +423,10 @@ def predict_and_generate_manifest(model, dataloader_dict, args, verbose = True):
for fn in dataloader_dict:
fwd_detections, fwd_regressions, fwd_classifications, bck_detections, bck_regressions, bck_classifications = generate_predictions(model, dataloader_dict[fn], args, verbose=verbose)

fwd_predictions_fp = export_to_selection_table(fwd_detections, fwd_regressions, fwd_classifications, fn, args, is_bck=False, verbose=verbose)
fwd_predictions_fp = export_to_selection_table(fwd_detections, fwd_regressions, fwd_classifications, fn, args, is_bck=False, verbose=verbose, detection_threshold=args.detection_threshold)
if model.is_bidirectional:
assert all(x is not None for x in [bck_detections, bck_classifications, bck_regressions])
bck_predictions_fp = export_to_selection_table(bck_detections, bck_regressions, bck_classifications, fn, args, is_bck=True, verbose=verbose)
bck_predictions_fp = export_to_selection_table(bck_detections, bck_regressions, bck_classifications, fn, args, is_bck=True, verbose=verbose, detection_threshold=args.detection_threshold)
else:
assert all(x is None for x in [bck_detections, bck_classifications, bck_regressions])
bck_predictions_fp = None
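For context on what the new `detection_threshold` argument controls: it is handed to `scipy.signal.find_peaks` as `height`, so only peaks in the detection track at or above the threshold make it into the selection table. A minimal sketch with made-up numbers (the real track comes from the model, and `distance` corresponds to `args.peak_distance`):

```python
import numpy as np
from scipy.signal import find_peaks

# Hypothetical per-frame detection probabilities.
dets = np.array([0.1, 0.2, 0.9, 0.3, 0.1, 0.6, 0.2, 0.05, 0.7, 0.1])

# height plays the role of detection_threshold, distance of args.peak_distance.
det_peaks, properties = find_peaks(dets, height=0.5, distance=2)
det_probs = properties["peak_heights"]

print(det_peaks)  # [2 5 8] -- indices of peaks at or above 0.5
print(det_probs)  # [0.9 0.6 0.7] -- their heights
```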
53 changes: 31 additions & 22 deletions voxaboxen/inference/inference.py
@@ -6,54 +6,63 @@
from voxaboxen.inference.params import parse_inference_args
from voxaboxen.training.params import load_params
from voxaboxen.model.model import DetectionModel, DetectionModelStereo
from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table
from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table, combine_fwd_bck_preds
from voxaboxen.data.data import get_single_clip_data

device = "cuda" if torch.cuda.is_available() else "cpu"

def inference(inference_args):
inference_args = parse_inference_args(inference_args)
args = load_params(inference_args.model_args_fp)
args = load_params(inference_args.model_args_fp)
files_to_infer = pd.read_csv(inference_args.file_info_for_inference)

output_dir = os.path.join(args.experiment_dir, 'inference')
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# model
if hasattr(args,'stereo') and args.stereo:
model = DetectionModelStereo(args)
else:
model = DetectionModel(args)
model_checkpoint_fp = os.path.join(args.experiment_dir, "model.pt")
os.makedirs(output_dir)

# model
#if hasattr(args,'stereo') and args.stereo:
#model = DetectionModelStereo(args)
#else:
model = DetectionModel(args)
model_checkpoint_fp = os.path.join(args.experiment_dir, "final-model.pt")
print(f"Loading model weights from {model_checkpoint_fp}")
cp = torch.load(model_checkpoint_fp)
model.load_state_dict(cp["model_state_dict"])
model.load_state_dict(cp)
model = model.to(device)

for i, row in files_to_infer.iterrows():
audio_fp = row['audio_fp']
fn = row['fn']

if not os.path.exists(audio_fp):
print(f"Could not locate file {audio_fp}")
continue

try:
dataloader = get_single_clip_data(audio_fp, args.clip_duration/2, args)
except:
print(f"Could not load file {audio_fp}")
continue

if len(dataloader) == 0:
print(f"Skipping {fn} because it is too short")
continue

detections, regressions, classifications = generate_predictions(model, dataloader, args, verbose = True)

target_fp = export_to_selection_table(detections, regressions, classifications, fn, args, verbose=True, target_dir=output_dir, detection_threshold = inference_args.detection_threshold, classification_threshold = inference_args.classification_threshold)

print(f"Saving predictions for {fn} to {target_fp}")

if inference_args.disable_bidirectional and not model.is_bidirectional:
print('Warning: you have passed the disable-bidirectional arg but model is not is_bidirectional')
detections, regressions, classifs, rev_detections, rev_regressions, rev_classifs = generate_predictions(model, dataloader, args, verbose = True)
fwd_target_fp = export_to_selection_table(detections, regressions, classifs, fn, args, is_bck=False, verbose=True, target_dir=output_dir, detection_threshold=inference_args.detection_threshold, classif_threshold=inference_args.classification_threshold)
if model.is_bidirectional and not inference_args.disable_bidirectional:
rev_target_fp = export_to_selection_table(rev_detections, rev_regressions, rev_classifs, fn, args, is_bck=True, verbose=True, target_dir=output_dir, detection_threshold=inference_args.detection_threshold, classif_threshold=inference_args.classification_threshold)
comb_target_fp, match_target_fp = combine_fwd_bck_preds(args.experiment_output_dir, fn, comb_iou_threshold=args.comb_iou_threshold, comb_discard_threshold=model.comb_discard_thresh.item())
print(f"Saving predictions for {fn} to {comb_target_fp}")


#preds_manifest = predict_and_generate_manifest(model, dataloader_dict

else:
print(f"Saving predictions for {fn} to {fwd_target_fp}")

if __name__ == "__main__":
main(sys.argv[1:])
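A point worth spelling out about the checkpoint handling (toy illustration, not code from the commit): `train_model.py` now writes `model.state_dict()` straight to `final-model.pt`, so inference calls `model.load_state_dict(cp)` on the loaded object, whereas the older `model.pt` wrapped the state dict under a `"model_state_dict"` key:

```python
import torch
import torch.nn as nn

toy = nn.Linear(4, 2)  # stand-in for DetectionModel

# Old layout: a dict wrapping the state dict under a key.
torch.save({"model_state_dict": toy.state_dict()}, "model.pt")
cp = torch.load("model.pt")
toy.load_state_dict(cp["model_state_dict"])

# New layout, as written by train_model.py: the raw state dict.
torch.save(toy.state_dict(), "final-model.pt")
cp = torch.load("final-model.pt")
toy.load_state_dict(cp)
```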
7 changes: 4 additions & 3 deletions voxaboxen/inference/params.py
@@ -5,11 +5,12 @@

def parse_inference_args(inference_args):
parser = argparse.ArgumentParser()

parser.add_argument('--model-args-fp', type=str, required=True, help = "filepath of model params saved as a yaml")
parser.add_argument('--file-info-for-inference', type=str, required=True, help = "filepath of info csv listing filenames and filepaths of audio for inference")
parser.add_argument('--detection-threshold', type=float, default=0.5, help="detection peaks need to be at or above this threshold to make it into the exported selection table")
parser.add_argument('--classification-threshold', type=float, default=0.0, help="classification probability needs to be at or above this threshold to not be labeled as Unknown")

inference_args = parser.parse_args(inference_args)
parser.add_argument('--disable-bidirectional', action='store_true')

inference_args = parser.parse_args(inference_args)
return inference_args
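Example of driving the parser, including the new flag (paths are placeholders):

```python
from voxaboxen.inference.params import parse_inference_args

inference_args = parse_inference_args([
    "--model-args-fp", "experiments/demo/params.yaml",    # placeholder path
    "--file-info-for-inference", "inference_files.csv",   # placeholder path
    "--detection-threshold", "0.5",
    "--classification-threshold", "0.0",
    "--disable-bidirectional",  # new in this commit
])
print(inference_args.disable_bidirectional)  # True
```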
1 change: 1 addition & 0 deletions voxaboxen/model/model.py
@@ -53,6 +53,7 @@ def __init__(self, args, embedding_dim=768):
self.args = args
aves_sr = args.sr // args.scale_factor
self.detection_head = DetectionHead(args, embedding_dim = embedding_dim)
self.comb_discard_thresh = nn.Parameter(torch.tensor(0.))
if self.is_bidirectional:
self.rev_detection_head = DetectionHead(args, embedding_dim = embedding_dim)

21 changes: 11 additions & 10 deletions voxaboxen/training/train_model.py
@@ -1,12 +1,12 @@
import torch.nn as nn
import pandas as pd
from voxaboxen.data.data import get_test_dataloader, get_val_dataloader
import torch
from voxaboxen.model.model import DetectionModel, DetectionModelStereo
from voxaboxen.data.data import get_test_dataloader, get_val_dataloader
from voxaboxen.model.model import DetectionModel
from voxaboxen.training.train import train
from voxaboxen.training.params import parse_args, set_seed, save_params
from voxaboxen.evaluation.evaluation import generate_predictions, export_to_selection_table, get_metrics, summarize_metrics, predict_and_generate_manifest, evaluate_based_on_manifest
from voxaboxen.evaluation.evaluation import predict_and_generate_manifest, evaluate_based_on_manifest

import yaml
import sys
import os

@@ -36,9 +36,9 @@ def train_model(args):
save_params(args)
model = DetectionModel(args)

if args.reload_from is not None:
checkpoint = torch.load(os.path.join(args.project_dir, args.reload_from, 'model.pt'))
model.load_state_dict(checkpoint['model_state_dict'])
#if args.reload_from is not None:
#checkpoint = torch.load(os.path.join(args.project_dir, args.reload_from, 'model.pt'))
#model.load_state_dict(checkpoint['model_state_dict'])

## Training
trained_model = train(model, args)
@@ -49,25 +49,26 @@

val_manifest = predict_and_generate_manifest(trained_model, val_dataloader, args)

model.comb_discard_thresh = -1
if model.is_bidirectional:
best_f1 = 0
for comb_discard_thresh in [.3,.35,.4,.45,.5,.55,.6,.65,.75,.8,.85,.9]:
val_metrics, val_conf_mats = evaluate_based_on_manifest(val_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results') , iou=0.5, class_threshold=0.5, comb_discard_threshold=comb_discard_thresh)
new_f1 = val_metrics['comb']['macro']['f1']
if new_f1 > best_f1:
model.comb_discard_thresh = comb_discard_thresh
model.comb_discard_thresh = nn.Parameter(torch.tensor(comb_discard_thresh))
best_f1 = new_f1
print(f'IOU: 0.5 class_thresh: 0.5 Comb discard threshold: {comb_discard_thresh}')
print_metrics(val_metrics, just_one_label=(len(args.label_set)==1))
print(f'Using comb_discard_thresh: {model.comb_discard_thresh}')

test_manifest = predict_and_generate_manifest(trained_model, test_dataloader, args)
for iou in [0.2, 0.5, 0.8]:
test_metrics, test_conf_mats = evaluate_based_on_manifest(test_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results') , iou=iou, class_threshold=0.5, comb_discard_threshold=model.comb_discard_thresh)
test_metrics, test_conf_mats = evaluate_based_on_manifest(test_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results') , iou=iou, class_threshold=0.5, comb_discard_threshold=model.comb_discard_thresh.item())
print(f'Test with IOU{iou}')
print_metrics(test_metrics, just_one_label=(len(args.label_set)==1))

torch.save(model.state_dict(), os.path.join(args.experiment_dir, 'final-model.pt'))

if __name__ == "__main__":
train_model(sys.argv[1:])

