-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
92 lines (75 loc) · 2.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import matplotlib.pyplot as plt, numpy as np
from itertools import combinations
from ggs import GGS
def apply_ggs(data, fs=0.5, lmbda=1, plot=False):
    """Segment ``data`` with ``GGS`` (from the ``ggs`` module) and return breakpoints.

    Args:
        data: NumPy array. A 1-D array is wrapped into a single row; any other
            array is transposed before being handed to ``GGS`` (presumably
            turning ``(samples, channels)`` into ``(channels, samples)`` --
            confirm against the ``ggs`` input layout).
        fs: Sampling rate. The breakpoint budget passed to ``GGS`` is
            ``n_samples // (300 * fs)``, i.e. one breakpoint per five minutes
            if ``fs`` is in Hz (assumption; not checked here).
        lmbda: Forwarded unchanged as the third argument of ``GGS``
            (presumably its regularization weight).
        plot: If true, plot the first row of the (reshaped) data together with
            the breakpoints via ``plot_ggs``.

    Returns:
        The breakpoints produced by ``GGS``. When ``GGS`` returns a list of
        breakpoint lists, only the last list is returned.
    """
    # Reshape so that the last axis indexes samples.
    if len(data.shape) == 1:
        data = data[None, ...]
    else:
        data = data.T

    samples_per_window = 60 * 5 * fs
    max_breaks = int(data.shape[-1] // samples_per_window)

    breakpoints, _ = GGS(data, max_breaks, lmbda)
    # GGS may report one breakpoint list per step; keep only the final one.
    if isinstance(breakpoints[0], list):
        breakpoints = breakpoints[-1]

    if plot:
        plot_ggs(data[0], breakpoints)
    return breakpoints
def plot_ggs(signal, bps):
    """Show ``signal`` in a 20x4 figure with a dashed black line at each breakpoint.

    Args:
        signal: Sequence of samples, plotted against its index.
        bps: Iterable of x positions (sample indices) to mark.
    """
    plt.figure(figsize=(20, 4))
    plt.plot(signal)
    for breakpoint_x in bps:
        plt.axvline(x=breakpoint_x, color="black", linestyle="--")
    plt.show()
def segment_ts(ts, bps):
    """Cut a time series at breakpoints and stack the pieces as NaN-padded rows.

    Segment ``i`` is ``ts[bps[i]:bps[i + 1]]``. Shorter segments are padded at
    the end (along the time axis only) with NaN so that every row has the
    length of the longest segment.

    Args:
        ts: Array-like time series. The first axis is time; any further axes
            (e.g. channels) are kept unchanged.
        bps: Sequence of at least two slice boundaries (sample indices).

    Returns:
        ``np.ndarray`` of shape ``(len(bps) - 1, max_len) + ts.shape[1:]``.
        Floating/complex inputs keep their dtype; other dtypes are promoted to
        ``float64`` because they cannot hold NaN.

    Raises:
        ValueError: If ``bps`` has fewer than two entries.
    """
    ts = np.asarray(ts)
    segments = [ts[start:end] for start, end in zip(bps[:-1], bps[1:])]
    if not segments:
        raise ValueError("segment_ts requires at least two breakpoints")

    max_len = max(len(seg) for seg in segments)
    # NaN padding needs an inexact dtype; integer/bool series would fail.
    dtype = ts.dtype if np.issubdtype(ts.dtype, np.inexact) else np.float64
    out = np.full((len(segments), max_len) + ts.shape[1:], np.nan, dtype=dtype)
    for row, seg in zip(out, segments):
        # Fill only the time axis; padding every axis would change the channel
        # dimension of multichannel input and make the rows unstackable.
        row[: len(seg)] = seg
    return out
def _cluster_colors(num):
    """Return a list of at least ``num`` colors for shading ``num`` clusters.

    Up to five clusters use fixed hand-picked palettes; for more clusters the
    colors of matplotlib's default property cycle are repeated so that every
    cluster index has a color instead of raising IndexError.
    """
    if num == 5:
        return ["#003f5c", "#58508d", "#bc5090", "#ff6361", "#ffa600"]
    if num == 4:
        return ["#003f5c", "#7a5195", "#ef5675", "#ffa600"]
    if num == 3:
        return ["#003f5c", "#bc5090", "#ffa600"]
    if num <= 2:
        return ["#003f5c", "#ffa600"]
    cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    return [cycle[k % len(cycle)] for k in range(num)]


def plot_cluster(signal, gt_bps, clusters, bps, fs=0.5):
    """Plot a signal with its ground-truth segments shaded by cluster label.

    Args:
        signal: Sequence of samples, drawn as a black line against time in
            minutes.
        gt_bps: Ground-truth breakpoints (sample indices). Segment ``i`` spans
            ``gt_bps[i]`` to ``gt_bps[i + 1]`` and is shaded with the color of
            ``clusters[i]``.
        clusters: Integer cluster label per ground-truth segment, expected to
            be 0-based (shown in the legend as ``cluster {label + 1}``).
        bps: Predicted breakpoints (sample indices), drawn as dashed lines.
        fs: Sampling rate used to convert sample indices to minutes
            (``60 * fs`` samples per minute, assuming ``fs`` is in Hz).
            Defaults to 0.5.

    Side effects:
        Sets ``plt.rcParams["font.size"]`` globally, opens a new figure and
        calls ``plt.show()``.
    """
    plt.rcParams.update({"font.size": 14})
    plt.figure(figsize=(20, 4))
    norm = 60 * fs  # samples per minute
    duration = len(signal) / norm
    time = np.linspace(0, duration, num=len(signal))
    plt.plot(time, signal, color="black")
    plt.xlim(0, duration)
    plt.xticks(np.arange(0, duration, 5))
    plt.xlabel("Time (min)")
    plt.ylabel("Clustered Ground Truth")

    labels_present = set(clusters)
    # For labels 0..k-1 this is k; it is widened when labels have gaps so
    # that every label still indexes a valid color.
    num = max(len(labels_present), max(labels_present, default=-1) + 1)
    colors = _cluster_colors(num)

    # The first span drawn for each cluster becomes its legend handle.
    legend_handles = {}
    for i in range(len(gt_bps) - 1):
        label = clusters[i]
        span = plt.axvspan(
            gt_bps[i] / norm,
            gt_bps[i + 1] / norm,
            facecolor=colors[label],
            alpha=0.3,
        )
        legend_handles.setdefault(label, span)

    # Order legend entries numerically by cluster index (not as strings).
    order = sorted(legend_handles)
    plt.legend(
        [legend_handles[c] for c in order],
        [f"cluster {c + 1}" for c in order],
        loc="upper right",
    )

    for x in bps:
        plt.axvline(x=x / norm, linestyle="--", color="black")
    plt.show()
def jaccard(set1, set2):
    """Return the Jaccard index ``|A & B| / |A | B|`` of two collections.

    Inputs are treated as sets: duplicate values are ignored, so e.g.
    ``[1, 1]`` and ``[1]`` are identical. Two empty inputs are considered
    identical and score 1.0 instead of dividing by zero.

    Args:
        set1: Array-like of values.
        set2: Array-like of values.

    Returns:
        A float in ``[0, 1]``.
    """
    a = np.unique(set1)
    b = np.unique(set2)
    intersection = len(np.intersect1d(a, b, assume_unique=True))
    union = len(a) + len(b) - intersection
    if union == 0:
        return 1.0
    return intersection / union
def covering_metric(bps, gt_bps, length):
    """Return how well a predicted segmentation covers the ground truth.

    Computes ``(1 / length) * sum_A |A| * max_P J(A, P)``, where ``A`` ranges
    over ground-truth segments ``[gt_bps[i], gt_bps[i + 1])``, ``P`` ranges over
    predicted segments formed by *consecutive* entries of ``sorted(bps)``, and
    ``J`` is the Jaccard index. Only the predicted partition itself is
    compared, so over-segmentation is penalised; unions of adjacent
    predicted segments are not candidates.

    Breakpoints are treated as integer sample indices, so the Jaccard index
    of two half-open index ranges is computed from their overlap directly
    instead of materialising index arrays.

    Args:
        bps: Predicted breakpoints (sample indices).
        gt_bps: Ground-truth breakpoints (sample indices), in order.
        length: Normaliser, normally the number of samples in the series.

    Returns:
        The covering score; 0.0 when ``bps`` defines no segment.
    """
    edges = sorted(bps)
    predicted = list(zip(edges[:-1], edges[1:]))
    cover = 0
    for start, end in zip(gt_bps[:-1], gt_bps[1:]):
        size = max(0, end - start)
        if size == 0:
            continue  # an empty ground-truth segment contributes nothing
        best = 0.0
        for p_start, p_end in predicted:
            p_size = max(0, p_end - p_start)
            overlap = max(0, min(end, p_end) - max(start, p_start))
            # size > 0, so the union below is always positive.
            best = max(best, overlap / (size + p_size - overlap))
        cover += size * best
    return cover / length