Skip to content

Commit b0985ec

Browse files
committed
More detailed evaluation
1 parent 66e9813 commit b0985ec

File tree

4 files changed

+81
-33
lines changed

4 files changed

+81
-33
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ scipy~=1.10.1
99
sphinx~=7.0.1
1010
sphinx-press-theme~=0.8.0
1111
sphinxcontrib-video~=0.1.1
12-
deepdiff~=6.3.0
12+
deepdiff~=6.3.0
13+
opencv-python~=4.8.1.78

src/main_evaluation.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> argparse.Arg
3434
parser.add_argument("--n_samples",
3535
type=int,
3636
metavar="N",
37-
default=180,
37+
default=60,
3838
help="Number of samples to evaluate."
3939
)
4040
parser.add_argument('--simplified',
@@ -358,9 +358,17 @@ def analyze_interrupt_line(img: Tensor, baseline_img: Tensor, lateral_features:
358358
"""
359359
baseline_img = baseline_img[-1].squeeze()
360360
mask = torch.where(img != baseline_img, True, False).unsqueeze(0).repeat(lateral_features[-1].shape[1], 1, 1)
361-
baseline_lateral_features = baseline_lateral_features[-1].squeeze(0)[mask]
362-
lateral_features = lateral_features[-1].squeeze(0)[mask]
363-
accuracy = 1. - F.l1_loss(lateral_features, baseline_lateral_features)
361+
baseline_lateral_features_masked = baseline_lateral_features[-1].squeeze(0)[mask]
362+
lateral_features_masked = lateral_features[-1].squeeze(0)[mask]
363+
mask_active = torch.where(baseline_lateral_features_masked > 0., True, False)
364+
if baseline_lateral_features_masked[mask_active].numel() > 0:
365+
accuracy = 1. - F.l1_loss(lateral_features_masked[mask_active], baseline_lateral_features_masked[mask_active])
366+
else:
367+
accuracy = torch.tensor(0.)
368+
# mask = torch.where(img != baseline_img, True, False)
369+
# baseline_lateral_features_masked = torch.clip(torch.sum(baseline_lateral_features[-1].squeeze(0), dim=0)[mask], max=1)
370+
# lateral_features_masked = torch.clip(torch.sum(lateral_features[-1].squeeze(0), dim=0)[mask], max=1)
371+
# accuracy = 1. - F.l1_loss(lateral_features_masked, baseline_lateral_features_masked)
364372
return accuracy.item()
365373

366374

@@ -439,7 +447,6 @@ def process_data(
439447
"""
440448
Processes the data and store the network activations as video
441449
:param generator: Data generator
442-
:param eval_args: Evaluation arguments
443450
:param config: Configuration
444451
:param fabric: Fabric instance
445452
:param feature_extractor: Feature extractor
@@ -450,7 +457,7 @@ def process_data(
450457
avg_noise_meter = AverageMeter()
451458
avg_line_recon_accuracy_meter = AverageMeter()
452459
avg_recon_accuracy_meter, avg_recon_recall_meter, avg_recon_precision_meter = AverageMeter(), AverageMeter(), AverageMeter()
453-
fp = f"../tmp/v2/{config['run']['load_state_path'].split('.')[0]}_{'noise:'+str(config['noise']) if config['noise'] > 0 else 'no-noise'}_li-{config['line_interrupt']}.mp4"
460+
fp = f"../tmp/v3/{config['config']}_{'noise:'+str(config['noise']) if config['noise'] > 0 else 'no-noise'}_li-{config['line_interrupt']}.mp4"
454461
if Path(fp).exists():
455462
Path(fp).unlink()
456463
out = cv2.VideoWriter(fp, cv2.VideoWriter_fourcc(*'mp4v'), config['fps'],
@@ -470,9 +477,39 @@ def process_data(
470477
result = ci.create_image(inp[timestep], inp_features[timestep, 0], l1_act[timestep, 0], l1_act_prob[timestep, 0])
471478
out.write(cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
472479
out.release()
480+
l1_acts = torch.stack(l1_acts)
473481
if 'store_baseline_activations_path' in config and config['store_baseline_activations_path'] is not None:
474-
torch.save([torch.stack(l1_acts), torch.stack(imgs_)], config['store_baseline_activations_path'])
475-
print("Video stored at", fp)
482+
torch.save([l1_acts, torch.stack(imgs_)], config['store_baseline_activations_path'])
483+
# print("Video stored at", fp)
484+
485+
calc_overlap = True
486+
if calc_overlap:
487+
# consider final timestep
488+
l1_acts_f = l1_acts[:, -1].squeeze()
489+
490+
# calculate active / inactive cells
491+
avg_activate_cells = torch.mean(torch.sum(l1_acts_f > 0, dim=(1, 2, 3)).float())
492+
avg_inactivate_cells = torch.mean(torch.sum(l1_acts_f == 0, dim=(1, 2, 3)).float())
493+
print(f"Average activated cells in net fragments: {avg_activate_cells}")
494+
print(f"Average inactivated cells in net fragments: {avg_inactivate_cells}")
495+
496+
# calculate overlap
497+
l1_acts_f = l1_acts_f.view(l1_acts_f.shape[0], -1)
498+
overlaps, boverlaps = [], []
499+
for v in range(l1_acts_f.shape[0]):
500+
biggest_overlap = 0
501+
for v2 in range(l1_acts_f.shape[0]):
502+
if v != v2:
503+
overlap = torch.sum(l1_acts_f[v] * l1_acts_f[v2]) / torch.sum(l1_acts_f[v] + l1_acts_f[v2])
504+
overlaps.append(overlap)
505+
if overlap > biggest_overlap:
506+
biggest_overlap = overlap
507+
boverlaps.append(biggest_overlap)
508+
509+
overlaps = torch.tensor(overlaps)
510+
print(f"Average overlap: {torch.mean(overlaps)}")
511+
print(f"Average biggest overlap: {torch.mean(torch.tensor(boverlaps))}")
512+
476513
print(f"Average Noise Reduction: {avg_noise_meter.mean}")
477514
print(f"Average Interrupt Line Reconstruction Accuracy: {avg_line_recon_accuracy_meter.mean}")
478515
print(f"Average Reconstruction Accuracy: {avg_recon_accuracy_meter.mean}")

src/main_lateral_connections.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
from typing import Any, Dict, Optional
44

55
import lightning.pytorch as pl
6-
import numpy as np
76
import torch
8-
import torch.nn.functional as F
97
import wandb
108
from lightning import Fabric
119
from torch import Tensor
@@ -16,8 +14,8 @@
1614

1715
from data import loaders_from_config
1816
from lateral_connections.feature_extractor.straight_line_pl_modules import FixedFilterFeatureExtractor
19-
from lateral_connections.s2_rbm import L2RBM
2017
from lateral_connections.s1_lateral_connections import LateralNetwork
18+
from lateral_connections.s2_rbm import L2RBM
2119
from tools import loggers_from_conf
2220
from tools.store_load_run import load_run, save_run
2321
from utils import get_config, print_start, print_warn
@@ -216,7 +214,7 @@ def cycle(
216214
z_float, z = lateral_network(x_in)
217215

218216
# z2, z2_feedback, h, loss = l2.eval_step(z)
219-
#
217+
#
220218
# if epoch > 10:
221219
# mask_active = (z > 0) | (z2_feedback > 0)
222220
# if F.mse_loss(z[mask_active], z2_feedback[mask_active]) < .1:
@@ -263,7 +261,7 @@ def cycle(
263261

264262
if store_tensors:
265263
return features, torch.stack(input_features, dim=1), torch.stack(lateral_features, dim=1), torch.stack(
266-
lateral_features_f, dim=1),None, None
264+
lateral_features_f, dim=1), None, None
267265

268266

269267
def single_train_epoch(
@@ -349,10 +347,10 @@ def single_eval_epoch(
349347

350348
assert not wandb_b or wandb_b and store_plots, "Wandb logging requires storing the plots."
351349

352-
if False:
350+
if plot:
353351
if epoch == 0:
354352
feature_extractor.plot_model_weights(show_plot=plot)
355-
#
353+
#
356354
plots_fp = lateral_network.plot_samples(plt_img,
357355
plt_features,
358356
plt_input_features,

src/plot_robustness_paper.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import pandas as pd
55
from matplotlib import pyplot as plt
66

7-
JSON_FP = Path(".").absolute().parent / "tmp" / "experiment_results.json"
7+
JSON_FP = Path(".").absolute().parent / "tmp" / "experiment_results_li.json"
8+
# JSON_FP = Path(".").absolute().parent / "tmp" / "experiment_results.json"
89

910

1011
def replace_square_list(sl):
@@ -27,10 +28,11 @@ def get_data():
2728
result = {
2829
'noise': data['config']['noise'],
2930
'line_interrupt': data['config']['line_interrupt'],
30-
'act_threshold': data['config']['lateral_model']['l1_params']['act_threshold'],
31+
'act_bias': round(data['config']['lateral_model']['l1_params']['act_threshold']-0.5, 2),
3132
'square_factor': replace_square_list(data['config']['lateral_model']['l1_params']['square_factor']),
3233
'noise_reduction': data['noise_reduction'],
3334
'avg_line_recon_accuracy_meter': data['avg_line_recon_accuracy_meter'],
35+
'avg_line_recon_accuracy_meter_2': (data['avg_line_recon_accuracy_meter']-0.75)/0.25,
3436
'recon_accuracy': data['recon_accuracy'],
3537
'recon_recall': data['recon_recall'],
3638
'recon_precision': data['recon_precision'],
@@ -46,24 +48,28 @@ def feature_noise_to_location_noise(feature_noise, round_=False):
4648
result = np.round(result, 2)
4749
return result
4850

49-
def plot_line(data, x_key, x_label, y_key, y_label, z_key, z_label, plot_key, plot_label, x2_func=None, x2_label=None):
51+
def plot_line(data, x_key, x_label, y_key, y_label, z_key, z_label, plot_key, plot_label, ymin=None,
52+
x2_func=None, x2_label=None):
5053

51-
fig, axs = plt.subplots(ncols=len(data[plot_key].unique()), figsize=(18, 6), dpi=100)
54+
fig, axs = plt.subplots(ncols=len(data[plot_key].unique()), figsize=(13, 3), dpi=100)
5255

53-
for ax, pk in zip(axs, data[plot_key].unique()):
56+
for ax, pk in zip(axs, sorted(data[plot_key].unique())):
5457
data_ = data[data[plot_key] == pk]
5558

5659
z_values = list(data_[z_key].unique())
5760
z_values = sorted(z_values)
5861

5962
for zv in z_values:
6063
z = data_[data_[z_key] == zv]
61-
ax.plot(z[x_key].values, z[y_key].values, label="{} = {}".format(z_label, zv))
64+
ax.plot(z[x_key].values, z[y_key].values, label="{}{}".format(z_label, zv))
6265

63-
ax.set_title("{} = {}".format(plot_label, pk))
66+
ax.set_title("{}{}".format(plot_label, pk))
6467
ax.set_xlabel(x_label)
6568
ax.set_ylabel(y_label)
6669

70+
if ymin is not None:
71+
ax.set_ylim(ymax=1.0+0.02, ymin=ymin-0.02)
72+
6773
if x2_func is not None:
6874
ax2 = ax.twiny()
6975
ax2.set_xlim(ax.get_xlim())
@@ -74,6 +80,7 @@ def plot_line(data, x_key, x_label, y_key, y_label, z_key, z_label, plot_key, pl
7480
ax.legend()
7581
ax.grid()
7682

83+
plt.tight_layout()
7784
plt.show()
7885

7986

@@ -85,26 +92,31 @@ def plot(data):
8592
# plot noise only
8693
data_1 = data[data['line_interrupt'] == 0]
8794
plot_line(data_1, x_key="noise", x_label="Feature Noise", y_key="noise_reduction", y_label="Noise Reduction Rate",
88-
z_key='act_threshold', z_label='Act. Threshold', plot_key='square_factor', plot_label='Square Factor',
89-
x2_func=feature_noise_to_location_noise, x2_label="Spatial Noise")
95+
z_key='act_bias', z_label='Act. Bias b = ', plot_key='square_factor', plot_label='Power Factor γ = ',
96+
ymin=0.4, x2_func=feature_noise_to_location_noise, x2_label="Spatial Noise")
9097

9198
plot_line(data_1, x_key="noise", x_label="Feature Noise", y_key="recon_recall", y_label="Recall",
92-
z_key='act_threshold', z_label='Act. Threshold', plot_key='square_factor', plot_label='Square Factor',
93-
x2_func=feature_noise_to_location_noise, x2_label="Spatial Noise")
99+
z_key='act_bias', z_label='Act. Bias b = ', plot_key='square_factor', plot_label='Power Factor γ = ',
100+
ymin=0.2, x2_func=feature_noise_to_location_noise, x2_label="Spatial Noise")
94101

95102
plot_line(data_1, x_key="noise", x_label="Feature Noise", y_key="recon_precision", y_label="Precision",
96-
z_key='act_threshold', z_label='Act. Threshold', plot_key='square_factor', plot_label='Square Factor',
97-
x2_func=feature_noise_to_location_noise, x2_label="Spatial Noise")
103+
z_key='act_bias', z_label='Act. Bias b = ', plot_key='square_factor', plot_label='Power Factor γ = ',
104+
ymin=0.1, x2_func=feature_noise_to_location_noise, x2_label="Spatial Noise")
98105

99106
# plot line interrupt only
100107
data_1 = data[data['noise'] == 0.0]
101108
plot_line(data_1, x_key="line_interrupt", x_label="Line Interrupt", y_key="avg_line_recon_accuracy_meter",
102-
y_label="Feature Reconstruction Rate",
103-
z_key='act_threshold', z_label='Act. Threshold', plot_key='square_factor', plot_label='Square Factor')
109+
y_label="Feature Reconstruction Rate", z_key='act_bias', z_label='Act. Bias b = ',
110+
plot_key='square_factor', plot_label='Power Factor γ = ', ymin=0.0)
111+
# plot_line(data_1, x_key="line_interrupt", x_label="Line Interrupt", y_key="avg_line_recon_accuracy_meter_2",
112+
# y_label="Feature Reconstruction Rate 2", z_key='act_bias', z_label='Act. Bias b = ',
113+
# plot_key='square_factor', plot_label='Power Factor γ = ', ymin=0.0)
104114
plot_line(data_1, x_key="line_interrupt", x_label="Line Interrupt", y_key="recon_recall", y_label="Recall",
105-
z_key='act_threshold', z_label='Act. Threshold', plot_key='square_factor', plot_label='Square Factor')
115+
z_key='act_bias', z_label='Act. Bias b = ', plot_key='square_factor', plot_label='Power Factor γ = ',
116+
ymin=0.5)
106117
plot_line(data_1, x_key="line_interrupt", x_label="Line Interrupt", y_key="recon_precision", y_label="Precision",
107-
z_key='act_threshold', z_label='Act. Threshold', plot_key='square_factor', plot_label='Square Factor')
118+
z_key='act_bias', z_label='Act. Bias b = ', plot_key='square_factor', plot_label='Power Factor γ = ',
119+
ymin=0.6)
108120

109121

110122
if __name__ == '__main__':

0 commit comments

Comments
 (0)