SGL.py
import numpy as np
from utils import Laplacian_dual, Laplacian_inv, Laplacian
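
# The Laplacian, Laplacian_dual and Laplacian_inv operators are provided by utils and
# are not shown in this file. As a point of reference only, the sketch below shows what
# the Laplacian operator is assumed to compute, consistent with how `adjacency` below
# maps w to a matrix: w holds the strict upper-triangular edge weights (row-major) and
# L(w) = diag(A @ 1) - A. The name `_laplacian_reference` is hypothetical and is not
# used by the class.
def _laplacian_reference(w, p):
    A = np.zeros((p, p))
    A[np.triu_indices(p, k=1)] = w       # fill the strict upper triangle with w
    A = A + A.T                          # symmetrize to obtain the adjacency matrix
    return np.diag(A.sum(axis=1)) - A    # degree matrix minus adjacency
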
class LearnGraphTopology:
    """Learn a graph Laplacian from a matrix S (e.g. a sample covariance) under spectral constraints."""

    def __init__(self, S, alpha=0, beta=1e4, n_iter=10000, c1=0., c2=1e10, tol=1e-6):
        self.tol = tol
        self.S = S
        self.p = S.shape[0]
        # number of free edge weights: p * (p - 1) / 2
        self.n = int(self.p * (self.p - 1) / 2)
        self.alpha = alpha      # l1-penalty weight
        self.beta = beta        # spectral-constraint penalty weight
        self.n_iter = n_iter    # maximum number of iterations
        self.c1 = c1            # lower bound on the non-zero eigenvalues
        self.c2 = c2            # upper bound on the non-zero eigenvalues

    def w_init(self, w0_init, Sinv):
        """
        Initialization of w from the pseudo-inverse of S
        """
        if w0_init == 'naive':
            # invert the Laplacian operator on Sinv and clip negative weights to zero
            w0 = Laplacian_inv(Sinv, self.p, self.n)
            w0[w0 < 0] = 0
        else:
            raise ValueError('Method not implemented')
        return w0

    def update_w(self, w, Lw, U, lambda_, K):
        """
        Compute w update as equation 12
        """
        # c = Laplacian_dual(U diag(lambda) U^T - K / beta): target under the adjoint operator
        c = Laplacian_dual(U @ np.diag(lambda_) @ U.T - K / self.beta, self.p, self.n)
        # gradient of the w-subproblem (up to the constant factor beta)
        grad_f = Laplacian_dual(Lw, self.p, self.n) - c
        M_grad_f = Laplacian_dual(Laplacian(grad_f, self.p), self.p, self.n)
        wT_M_grad_f = np.sum(w * M_grad_f)
        dwT_M_dw = np.sum(grad_f * M_grad_f)
        ## exact line search for the step size t
        t = (wT_M_grad_f - np.sum(c * grad_f)) / dwT_M_dw
        ## fixed step size (no line search):
        # p = int(0.5 * (1 + np.sqrt(1 + 8 * w.shape[0])))
        # t = 0.5 * p
        # projected gradient step: keep the edge weights non-negative
        w_new = w - t * grad_f
        w_new[w_new < 0] = 0
        return w_new

    def adjacency(self, w):
        """
        Compute the adjacency matrix from w
        """
        Aw = np.zeros((self.p, self.p))
        k = 0
        # fill the strict upper triangle row by row, then symmetrize
        for i in range(0, self.p):
            for j in range(i + 1, self.p):
                Aw[i, j] = w[k]
                k += 1
        Aw = Aw + Aw.T
        return Aw

    def update_lambda(self, U, Lw, k):
        """
        Compute lambda update as proposed in the supplementary
        """
        q = Lw.shape[1] - k
        # diagonal of U^T Lw U gives the unconstrained eigenvalue update
        d = np.diag(np.dot(U.T, np.dot(Lw, U)))
        assert d.shape[0] == q
        lambda_ = (d + np.sqrt(d**2 + 4 / self.beta)) / 2
        # feasibility: c1 <= lambda_1 <= ... <= lambda_q <= c2 (up to numerical tolerance)
        cond = (lambda_[q - 1] - self.c2 <= 1e-9) and (lambda_[0] - self.c1 >= -1e-9) \
            and np.all(lambda_[1:q] - lambda_[0:q - 1] >= -1e-9)
        if cond:
            return lambda_
        # otherwise clip to the box [c1, c2] and re-check the ordering
        lambda_[lambda_ < self.c1] = self.c1
        lambda_[lambda_ > self.c2] = self.c2
        cond = (lambda_[q - 1] - self.c2 <= 1e-9) and (lambda_[0] - self.c1 >= -1e-9) \
            and np.all(lambda_[1:q] - lambda_[0:q - 1] >= -1e-9)
        if cond:
            return lambda_
        raise ValueError("Consider increasing the value of beta")

    def update_U(self, Lw, k):
        """
        Compute U update as equation 14
        """
        # eigenvectors of Lw in ascending order of eigenvalue; drop the first k
        _, eigvec = np.linalg.eigh(Lw)
        assert eigvec.shape[1] == self.p
        return eigvec[:, k:]

    def objective(self, Lw, lambda_, K, U):
        """
        Compute objective function - equation 8
        """
        term1 = np.sum(-np.log(lambda_))     # -log of the product of the non-zero eigenvalues
        term2 = np.trace(np.dot(K, Lw))      # data-fit / penalty term
        term3 = 0.5 * self.beta * np.linalg.norm(
            Lw - np.dot(U, np.dot(np.diag(lambda_), U.T)), ord='fro')**2
        return term1 + term2 + term3

    def learn_graph(self, k=1, w0_init='naive', eps=1e-4):
        # find an appropriate initial guess
        Sinv = np.linalg.pinv(self.S)
        # if w0_init is "naive", compute w0 from Sinv, else raise
        w = self.w_init(w0_init, Sinv)
        # compute quantities on the initial guess
        Lw = Laplacian(w, self.p)
        # l1-norm penalty factor
        H = self.alpha * (np.eye(self.p) - np.ones((self.p, self.p)))
        K = self.S + H
        U = self.update_U(Lw=Lw, k=k)
        lambda_ = self.update_lambda(U=U, Lw=Lw, k=k)
        objective_seq = []
        convergence = False
        for _ in range(self.n_iter):
            # block-coordinate updates for w, U and lambda
            w_new = self.update_w(w=w, Lw=Lw, U=U, lambda_=lambda_, K=K)
            Lw = Laplacian(w_new, self.p)
            U = self.update_U(Lw=Lw, k=k)
            lambda_ = self.update_lambda(U=U, Lw=Lw, k=k)
            # check for convergence
            convergence = np.linalg.norm(w_new - w, ord=2) < self.tol
            objective_seq.append(self.objective(Lw, lambda_, K, U))
            # accept the update so the returned w matches the returned Laplacian
            w = w_new
            if convergence:
                break
            # reweight the l1 penalty (eps avoids division by zero)
            K = self.S + H / (-Lw + eps)
        # compute the adjacency matrix
        Aw = self.adjacency(w)
        results = {'laplacian': Lw, 'adjacency': Aw, 'w': w, 'lambda': lambda_, 'U': U,
                   'convergence': convergence, 'objective_seq': objective_seq}
        return results
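

# Minimal usage sketch: it assumes the utils module providing Laplacian, Laplacian_dual
# and Laplacian_inv is importable, and runs the solver on synthetic data; the parameter
# choices below are illustrative only.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.standard_normal((200, 10))     # 200 samples, 10 nodes
    S = np.cov(X, rowvar=False)            # 10 x 10 sample covariance
    sgl = LearnGraphTopology(S, alpha=0.1, beta=1e4, n_iter=1000)
    res = sgl.learn_graph(k=1)             # k = 1: one zero Laplacian eigenvalue (connected graph)
    print("converged:", res['convergence'])
    print("learned adjacency:\n", np.round(res['adjacency'], 3))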