pass command-line args to run bidirectional, stereo, and/or segmentation modes
Lou1sM committed May 17, 2024
1 parent 0a12448 commit 7312733
Showing 5 changed files with 116 additions and 81 deletions.
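
The change that ties the five files together: DetectionModel.forward now always returns a six-element tuple, and the three reverse-direction slots are None unless the model was constructed with the new --bidirectional flag. A minimal sketch of the contract downstream callers rely on (illustrative guard; names follow the diffs below):

    # hypothetical caller; model and X as in generate_predictions below
    dets, regs, classifs, rev_dets, rev_regs, rev_classifs = model(X)
    if model.is_bidirectional:
        assert rev_dets is not None   # reverse head ran
    else:
        assert rev_dets is rev_regs is rev_classifs is None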
61 changes: 36 additions & 25 deletions voxaboxen/evaluation/evaluation.py
@@ -133,27 +133,29 @@ def generate_predictions(model, single_clip_dataloader, args, verbose = True):
             X = X.to(device = device, dtype = torch.float)
             X, _, _, _ = rms_and_mixup(X, None, None, None, False, args)
 
-            detection, regression, classif, rev_detection, rev_regression, rev_classif = model(X)
-            classif = torch.nn.functional.softmax(classif, dim=-1)
-            rev_classif = torch.nn.functional.softmax(rev_classif, dim=-1)
-
-            all_detections.append(detection)
-            all_regressions.append(regression)
-            all_classifs.append(classif)
-            all_rev_detections.append(rev_detection)
-            all_rev_regressions.append(rev_regression)
-            all_rev_classifs.append(rev_classif)
+            model_outputs = model(X)
+            assert isinstance(model_outputs, tuple)
+            all_detections.append(model_outputs[0])
+            all_regressions.append(model_outputs[1])
+            all_classifs.append(model_outputs[2].softmax(-1))
+            if model.is_bidirectional:
+                assert all(x is not None for x in model_outputs)
+                all_rev_detections.append(model_outputs[3])
+                all_rev_regressions.append(model_outputs[4])
+                all_rev_classifs.append(model_outputs[5].softmax(-1))
+            else:
+                assert all(x is None for x in model_outputs[3:])
 
             if args.is_test and i==15:
                 break
 
     all_detections = torch.cat(all_detections)
     all_regressions = torch.cat(all_regressions)
     all_classifs = torch.cat(all_classifs)
-    all_rev_detections = torch.cat(all_rev_detections)
-    all_rev_regressions = torch.cat(all_rev_regressions)
-    all_rev_classifs = torch.cat(all_rev_classifs)
 
+    if model.is_bidirectional:
+        all_rev_detections = torch.cat(all_rev_detections)
+        all_rev_regressions = torch.cat(all_rev_regressions)
+        all_rev_classifs = torch.cat(all_rev_classifs)
 
     ######## Todo: Need better checking that preds are the correct dur
     assert all_detections.size(dim=1) % 2 == 0
@@ -186,7 +188,10 @@ def assemble(d, r, c):
         return assembled_d.detach().cpu().numpy(), assembled_r.detach().cpu().numpy(), assembled_c.detach().cpu().numpy(),
 
     assembled_dets, assembled_regs, assembled_classifs = assemble(all_detections, all_regressions, all_classifs)
-    assembled_rev_dets, assembled_rev_regs, assembled_rev_classifs = assemble(all_rev_detections, all_rev_regressions, all_rev_classifs)
+    if model.is_bidirectional:
+        assembled_rev_dets, assembled_rev_regs, assembled_rev_classifs = assemble(all_rev_detections, all_rev_regressions, all_rev_classifs)
+    else:
+        assembled_rev_dets = assembled_rev_regs = assembled_rev_classifs = None
     return assembled_dets, assembled_regs, assembled_classifs, assembled_rev_dets, assembled_rev_regs, assembled_rev_classifs
 
 def generate_features(model, single_clip_dataloader, args, verbose = True):
@@ -419,7 +424,12 @@ def predict_and_generate_manifest(model, dataloader_dict, args, verbose = True):
         fwd_detections, fwd_regressions, fwd_classifications, bck_detections, bck_regressions, bck_classifications = generate_predictions(model, dataloader_dict[fn], args, verbose=verbose)
 
         fwd_predictions_fp = export_to_selection_table(fwd_detections, fwd_regressions, fwd_classifications, fn, args, is_bck=False, verbose=verbose)
-        bck_predictions_fp = export_to_selection_table(bck_detections, bck_regressions, bck_classifications, fn, args, is_bck=True, verbose=verbose)
+        if model.is_bidirectional:
+            assert all(x is not None for x in [bck_detections, bck_classifications, bck_regressions])
+            bck_predictions_fp = export_to_selection_table(bck_detections, bck_regressions, bck_classifications, fn, args, is_bck=True, verbose=verbose)
+        else:
+            assert all(x is None for x in [bck_detections, bck_classifications, bck_regressions])
+            bck_predictions_fp = None
         annotations_fp = dataloader_dict[fn].dataset.annot_fp
 
         fns.append(fn)
@@ -431,28 +441,29 @@ def predict_and_generate_manifest(model, dataloader_dict, args, verbose = True):
     return manifest
 
 def evaluate_based_on_manifest(manifest, args, output_dir, iou, class_threshold, comb_discard_threshold):
-    pred_types = ('fwd', 'bck', 'comb', 'match')
+    pred_types = ('fwd', 'bck', 'comb', 'match') if args.bidirectional else ('fwd',)
     metrics = {p:{} for p in pred_types}
     conf_mats = {p:{} for p in pred_types}
     conf_mat_labels = {}
 
     for i, row in manifest.iterrows():
         fn = row['filename']
         annots_fp = row['annotations_fp']
-        row['comb_predictions_fp'], row['match_predictions_fp'] = combine_fwd_bck_preds(args.experiment_output_dir, fn, comb_iou_threshold=args.comb_iou_threshold, comb_discard_threshold=comb_discard_threshold)
+        if args.bidirectional:
+            row['comb_predictions_fp'], row['match_predictions_fp'] = combine_fwd_bck_preds(args.experiment_output_dir, fn, comb_iou_threshold=args.comb_iou_threshold, comb_discard_threshold=comb_discard_threshold)
 
         for pred_type in pred_types:
             preds_fp = row[f'{pred_type}_predictions_fp']
             metrics[pred_type][fn] = get_metrics(preds_fp, annots_fp, args, iou, class_threshold)
             conf_mats[pred_type][fn], conf_mat_labels[pred_type] = get_confusion_matrix(preds_fp, annots_fp, args, iou, class_threshold)
 
     if output_dir is not None:
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
 
     # summarize and save metrics
     conf_mat_summaries = {}
-    for pred_type in ('fwd', 'bck', 'comb', 'match'):
+    for pred_type in pred_types:
         summary = summarize_metrics(metrics[pred_type])
         metrics[pred_type]['summary'] = summary
         metrics[pred_type]['macro'] = macro_metrics(summary)
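
One subtle fix in this hunk: the summary loop at the bottom now iterates pred_types instead of the hardcoded four-tuple, which would raise a KeyError on unidirectional runs where only 'fwd' metrics exist. A toy illustration of the gating (hypothetical flag value):

    bidirectional = False  # hypothetical; stands in for args.bidirectional
    pred_types = ('fwd', 'bck', 'comb', 'match') if bidirectional else ('fwd',)
    metrics = {p: {} for p in pred_types}
    for pred_type in pred_types:   # safe: never touches missing 'comb'/'match' keys
        metrics[pred_type]['summary'] = {}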
23 changes: 18 additions & 5 deletions voxaboxen/model/model.py
@@ -44,11 +44,17 @@ def unfreeze(self):
 class DetectionModel(nn.Module):
     def __init__(self, args, embedding_dim=768):
         super().__init__()
+        self.is_bidirectional = args.bidirectional
+        self.is_stereo = args.stereo
+        self.is_segmentation = args.segmentation
+        if self.is_stereo:
+            embedding_dim *= 2
         self.encoder = AvesEmbedding(args)
         self.args = args
         aves_sr = args.sr // args.scale_factor
         self.detection_head = DetectionHead(args, embedding_dim = embedding_dim)
-        self.rev_detection_head = DetectionHead(args, embedding_dim = embedding_dim)
+        if self.is_bidirectional:
+            self.rev_detection_head = DetectionHead(args, embedding_dim = embedding_dim)
 
     def forward(self, x):
         """
@@ -64,7 +70,12 @@ def forward(self, x):
         expected_dur_output = math.ceil(x.size(1)/self.args.scale_factor)
 
         x = x-torch.mean(x,axis=1,keepdim=True)
-        feats = self.encoder(x)
+        if self.is_stereo:
+            feats0 = self.encoder(x[:,0,:])
+            feats1 = self.encoder(x[:,1,:])
+            feats = torch.cat([feats0,feats1],dim=-1)
+        else:
+            feats = self.encoder(x)
 
         #aves may be off by 1 sample from expected
         pad = expected_dur_output - feats.size(1)
@@ -73,8 +84,11 @@ def forward(self, x):
 
         detection_logits, regression, class_logits = self.detection_head(feats)
         detection_probs = torch.sigmoid(detection_logits)
-        rev_detection_logits, rev_regression, rev_class_logits = self.rev_detection_head(feats)
-        rev_detection_probs = torch.sigmoid(rev_detection_logits)
+        if self.is_bidirectional:
+            rev_detection_logits, rev_regression, rev_class_logits = self.rev_detection_head(feats)
+            rev_detection_probs = torch.sigmoid(rev_detection_logits)
+        else:
+            rev_detection_probs = rev_regression = rev_class_logits = None
 
         return detection_probs, regression, class_logits, rev_detection_probs, rev_regression, rev_class_logits
 
@@ -161,7 +175,6 @@ def forward(self, x):
 
         return detection_probs, regression, class_logits
 
-
 def rms_and_mixup(X, d, r, y, train, args):
     if args.rms_norm:
         ms = torch.mean(X ** 2, dim = -1, keepdim = True)
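
The stereo path encodes each channel separately and concatenates along the feature axis, which is why __init__ doubles embedding_dim. A shape-only sketch with a stub encoder (AvesEmbedding's real interface is not shown in this diff; the stub only mimics the assumed (batch, time) -> (batch, frames, 768) shape):

    import torch

    B, T, D, SCALE = 2, 16000, 768, 320   # batch, samples, embed dim, aves downscaling

    def encode(mono):                      # stand-in for self.encoder
        return torch.zeros(mono.shape[0], mono.shape[1] // SCALE, D)

    x = torch.randn(B, 2, T)               # stereo clip: (batch, channel, time)
    feats = torch.cat([encode(x[:, 0, :]), encode(x[:, 1, :])], dim=-1)
    assert feats.shape == (B, T // SCALE, 2 * D)   # matches the doubled embedding_dim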
5 changes: 5 additions & 0 deletions voxaboxen/training/params.py
@@ -23,6 +23,8 @@ def parse_args(args,allow_unknown=False):
     parser.add_argument('--num-workers', type=int, default=8)
 
     # Model
+    parser.add_argument('--bidirectional', action='store_true', help="train and run inference in both directions and combine results")
+    parser.add_argument('--segmentation', action='store_true')
     parser.add_argument('--sr', type=int, default=16000)
     parser.add_argument('--scale-factor', type=int, default = 320, help = "downscaling performed by aves")
     parser.add_argument('--aves-model-weight-fp', type=str, default = "weights/aves-base-bio.torchaudio.pt")
@@ -77,6 +79,9 @@ def parse_args(args,allow_unknown=False):
     if args.clip_hop is None:
         setattr(args, "clip_hop", args.clip_duration/2)
 
+    if args.bidirectional and args.segmentation:
+        raise ValueError("bidirectional and segmentation settings are not currently compatible")
+
     if allow_unknown:
         return args, remaining
     else:
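
A quick sketch of the new incompatibility guard in isolation (parse_args itself may require other arguments, so this reproduces just the check on a bare Namespace):

    from argparse import Namespace

    def check_compat(args):
        # mirrors the guard added to parse_args above
        if args.bidirectional and args.segmentation:
            raise ValueError("bidirectional and segmentation settings are not currently compatible")

    check_compat(Namespace(bidirectional=True, segmentation=False))    # ok
    # check_compat(Namespace(bidirectional=True, segmentation=True))   # raises ValueError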
73 changes: 42 additions & 31 deletions voxaboxen/training/train.py
@@ -65,7 +65,7 @@ def train(model, args):
         if use_val:
             eval_scores = val_epoch(model, t, val_dataloader, args)
             # TODO: maybe plot evals for other pred_types
-            val_evals.append(eval_scores['comb'].copy())
+            val_evals.append(eval_scores['fwd'].copy())
             plot_eval(train_evals, learning_rates, args, val_evals=val_evals)
 
         val_evals_by_epoch = {i : e for i, e in enumerate(val_evals)}
@@ -77,7 +77,7 @@ def train(model, args):
         scheduler.step()
 
         if use_val and args.early_stopping:
-            current_f1 = eval_scores['comb']['f1']
+            current_f1 = eval_scores['comb']['f1'] if model.is_bidirectional else eval_scores['fwd']['f1']
             if current_f1 > best_f1:
                 print('found new best model')
                 best_f1 = current_f1
@@ -119,7 +119,7 @@ def train(model, args):
 
     # resave validation with best model
     if use_val:
-        val_epoch(model, t+1, val_dataloader, args)
+        val_epoch(model, args.n_epochs, val_dataloader, args)
 
     return model
 
@@ -153,52 +153,63 @@ def train_epoch(model, t, dataloader, detection_loss_fn, reg_loss_fn, class_loss
 
     evals = {}
-    normal_train_loss = 0; normal_losses = []; detection_losses = []; regression_losses = []; class_losses = []
+    train_loss = 0; losses = []; detection_losses = []; regression_losses = []; class_losses = []
     rev_train_loss = 0; rev_losses = []; rev_detection_losses = []; rev_regression_losses = []; rev_class_losses = []
-    train_loss = 0; losses = []
 
     data_iterator = tqdm.tqdm(dataloader)
-    for i, (X, d, r, y, rev_d, rev_r, rev_y) in enumerate(data_iterator):
+    for i, batch in enumerate(data_iterator):
         num_batches_seen = i
-        X = X.to(device = device, dtype = torch.float)
-        d = d.to(device = device, dtype = torch.float)
-        r = r.to(device = device, dtype = torch.float)
-        y = y.to(device = device, dtype = torch.float)
-        rev_d = rev_d.to(device = device, dtype = torch.float)
-        rev_r = rev_r.to(device = device, dtype = torch.float)
-        rev_y = rev_y.to(device = device, dtype = torch.float)
+        batch = [item.to(device, dtype=torch.float) for item in batch]
+        X, d, r, y = batch[:4]
 
+        # We mask out loss from each end of the clip, so the model isn't forced to learn to detect events that are partially cut off.
+        # This does not affect inference, because during inference we overlap clips at 50%
         X, d, r, y = rms_and_mixup(X, d, r, y, True, args)
-        _, rev_d, rev_r, rev_y = rms_and_mixup(X, rev_d, rev_r, rev_y, True, args)
         probs, regression, class_logits, rev_probs, rev_regression, rev_class_logits = model(X)
         detection_loss, reg_loss, class_loss = lf(d, probs, r, regression, y, class_logits, args=args, reg_loss_fn=reg_loss_fn, class_loss_fn=class_loss_fn)
-        rev_detection_loss, rev_reg_loss, rev_class_loss = lf(rev_d, rev_probs, rev_r, rev_regression, rev_y, rev_class_logits, args=args, reg_loss_fn=reg_loss_fn, class_loss_fn=class_loss_fn)
-        normal_loss = args.rho * class_loss + detection_loss + args.lamb * reg_loss
-        rev_loss = args.rho * rev_class_loss + rev_detection_loss + args.lamb * rev_reg_loss
-        loss = (normal_loss + rev_loss)/2
 
+        loss = args.rho * class_loss + detection_loss + args.lamb * reg_loss
         train_loss += loss.item()
-        rev_train_loss += rev_loss.item()
-        normal_train_loss += normal_loss.item()
-        normal_losses.append(normal_loss.item())
-        rev_losses.append(rev_loss.item())
         losses.append(loss.item())
         detection_losses.append(detection_loss.item())
         regression_losses.append(args.lamb * reg_loss.item())
         class_losses.append(args.rho * class_loss.item())
-        rev_detection_losses.append(rev_detection_loss.item())
-        rev_regression_losses.append(args.lamb * rev_reg_loss.item())
-        rev_class_losses.append(args.rho * rev_class_loss.item())
 
+        pbar_str = f"loss {np.mean(losses[-10:]):.5f}, det {np.mean(detection_losses[-10:]):.5f}, reg {np.mean(regression_losses[-10:]):.5f}, class {np.mean(class_losses[-10:]):.5f}"
+
+        if model.is_bidirectional:
+            assert all(x is not None for x in [rev_probs, rev_regression, rev_class_logits])
+            rev_d, rev_r, rev_y = batch[4:]
+            _, rev_d, rev_r, rev_y = rms_and_mixup(X, rev_d, rev_r, rev_y, True, args)
+            rev_detection_loss, rev_reg_loss, rev_class_loss = lf(rev_d, rev_probs, rev_r, rev_regression, rev_y, rev_class_logits, args=args, reg_loss_fn=reg_loss_fn, class_loss_fn=class_loss_fn)
+            rev_loss = args.rho * rev_class_loss + rev_detection_loss + args.lamb * rev_reg_loss
+            rev_train_loss += rev_loss.item()
+            rev_losses.append(rev_loss.item())
+            rev_detection_losses.append(rev_detection_loss.item())
+            rev_regression_losses.append(args.lamb * rev_reg_loss.item())
+            rev_class_losses.append(args.rho * rev_class_loss.item())
+            loss = (loss + rev_loss)/2
+            pbar_str += f" revloss {np.mean(rev_losses[-10:]):.5f}, revdet {np.mean(rev_detection_losses[-10:]):.5f}, revreg {np.mean(rev_regression_losses[-10:]):.5f}, revclass {np.mean(rev_class_losses[-10:]):.5f}"
+        else:
+            assert all(x is None for x in [rev_probs, rev_regression, rev_class_logits])
 
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
         if i > 10:
-            data_iterator.set_description(f"loss {np.mean(losses[-10:]):.5f}, det {np.mean(detection_losses[-10:]):.5f}, reg {np.mean(regression_losses[-10:]):.5f}, class {np.mean(class_losses[-10:]):.5f} revloss {np.mean(rev_losses[-10:]):.5f}, revdet {np.mean(rev_detection_losses[-10:]):.5f}, revreg {np.mean(rev_regression_losses[-10:]):.5f}, revclass {np.mean(rev_class_losses[-10:]):.5f}")
+            data_iterator.set_description(pbar_str)
 
         if args.is_test and i == 15: break
 
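
After this change the per-batch objective is the forward loss alone in unidirectional mode, and the mean of the forward and reverse losses in bidirectional mode. A self-contained sketch of that combination (rho and lamb are the existing class/regression loss weights):

    def combined_loss(det, reg, cls, rev=None, rho=1.0, lamb=1.0):
        # per-direction loss: detection + weighted regression + weighted classification
        fwd = det + lamb * reg + rho * cls
        if rev is None:                        # unidirectional: forward only
            return fwd
        rev_det, rev_reg, rev_cls = rev
        bck = rev_det + lamb * rev_reg + rho * rev_cls
        return (fwd + bck) / 2                 # bidirectional: average both directions

    assert combined_loss(1.0, 0.5, 0.25) == 1.75
    assert combined_loss(1.0, 0.5, 0.25, rev=(1.0, 0.5, 0.25)) == 1.75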
35 changes: 15 additions & 20 deletions voxaboxen/training/train_model.py
@@ -34,42 +34,37 @@ def train_model(args):
         os.makedirs(args.experiment_output_dir)
 
     save_params(args)
-    if hasattr(args,'stereo') and args.stereo:
-        model = DetectionModelStereo(args)
-    else:
-        model = DetectionModel(args)
+    model = DetectionModel(args)
 
     if args.reload_from is not None:
         checkpoint = torch.load(os.path.join(args.project_dir, args.reload_from, 'model.pt'))
         model.load_state_dict(checkpoint['model_state_dict'])
 
     ## Training
-    if args.n_epochs == 0:
-        trained_model = model
-    else:
-        trained_model = train(model, args)
+    trained_model = train(model, args)
 
     ## Evaluation
     test_dataloader = get_test_dataloader(args)
     val_dataloader = get_val_dataloader(args)
 
     val_manifest = predict_and_generate_manifest(trained_model, val_dataloader, args)
 
-    best_comb_discard_thresh = -1
-    best_f1 = 0
-    for comb_discard_thresh in [.3,.35,.4,.45,.5,.55,.6,.65,.75,.8,.85,.9]:
-        val_metrics, val_conf_mats = evaluate_based_on_manifest(val_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results'), iou=0.5, class_threshold=0.5, comb_discard_threshold=comb_discard_thresh)
-        new_f1 = val_metrics['comb']['macro']['f1']
-        if new_f1 > best_f1:
-            best_comb_discard_thresh = comb_discard_thresh
-            best_f1 = new_f1
-        print(f'IOU: 0.5 class_thresh: 0.5 Comb discard threshold: {comb_discard_thresh}')
-        print_metrics(val_metrics, just_one_label=(len(args.label_set)==1))
+    model.comb_discard_thresh = -1
+    if model.is_bidirectional:
+        best_f1 = 0
+        for comb_discard_thresh in [.3,.35,.4,.45,.5,.55,.6,.65,.75,.8,.85,.9]:
+            val_metrics, val_conf_mats = evaluate_based_on_manifest(val_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results'), iou=0.5, class_threshold=0.5, comb_discard_threshold=comb_discard_thresh)
+            new_f1 = val_metrics['comb']['macro']['f1']
+            if new_f1 > best_f1:
+                model.comb_discard_thresh = comb_discard_thresh
+                best_f1 = new_f1
+            print(f'IOU: 0.5 class_thresh: 0.5 Comb discard threshold: {comb_discard_thresh}')
+            print_metrics(val_metrics, just_one_label=(len(args.label_set)==1))
+        print(f'Using comb_discard_thresh: {model.comb_discard_thresh}')
 
     test_manifest = predict_and_generate_manifest(trained_model, test_dataloader, args)
-    print(f'Using thresh: {best_comb_discard_thresh}')
     for iou in [0.2, 0.5, 0.8]:
-        test_metrics, test_conf_mats = evaluate_based_on_manifest(test_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results'), iou=iou, class_threshold=0.5, comb_discard_threshold=best_comb_discard_thresh)
+        test_metrics, test_conf_mats = evaluate_based_on_manifest(test_manifest, args, output_dir = os.path.join(args.experiment_dir, 'test_results'), iou=iou, class_threshold=0.5, comb_discard_threshold=model.comb_discard_thresh)
         print(f'Test with IOU{iou}')
         print_metrics(test_metrics, just_one_label=(len(args.label_set)==1))
 
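
The validation sweep now stores the winning comb_discard_thresh on the model itself and only runs for bidirectional models. The underlying pattern, as a generic sketch (toy scoring function; the real score is macro F1 from evaluate_based_on_manifest):

    def sweep_threshold(evaluate, candidates=(.3, .4, .5, .6, .7, .8, .9)):
        # evaluate: callable mapping a discard threshold to a validation score
        best_thresh, best_score = -1, 0.0
        for t in candidates:
            score = evaluate(t)
            if score > best_score:
                best_thresh, best_score = t, score
        return best_thresh

    assert sweep_threshold(lambda t: 1 - abs(t - 0.6)) == 0.6   # toy example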
