forked from amber0309/HSIC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathHSIC.py
110 lines (76 loc) · 2.57 KB
/
HSIC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Python implementation of the Hilbert-Schmidt Independence Criterion (HSIC).
hsic_gam implements the HSIC independence test using a Gamma approximation
to the null distribution of the test statistic.
Python 2.7.12
Gretton, A., Fukumizu, K., Teo, C. H., Song, L., Scholkopf, B.,
& Smola, A. J. (2007). A kernel statistical test of independence.
In Advances in Neural Information Processing Systems (pp. 585-592).
Shoubo (shoubo.sub AT gmail.com)
09/11/2016
Inputs:
X       n by dim_x matrix
Y       n by dim_y matrix
alph    significance level of the test
Outputs:
testStat    test statistic
thresh      test threshold for a level-alph test
"""
from __future__ import division
import numpy as np
from scipy.stats import gamma
def _as_2d(a):
    """Return ``a`` as a 2-D array; a 1-D input is treated as a single column."""
    a = np.asarray(a)
    if a.ndim == 1:
        a = a.reshape(-1, 1)
    return a


def _sq_dists(pattern1, pattern2):
    """Return the (n1, n2) matrix of squared Euclidean distances between the
    rows of ``pattern1`` (n1 x d) and ``pattern2`` (n2 x d).

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b with broadcasting instead of
    tiled copies.  Small negative values caused by floating-point cancellation
    are clamped to zero.
    """
    sq1 = np.sum(pattern1 * pattern1, axis=1)[:, np.newaxis]
    sq2 = np.sum(pattern2 * pattern2, axis=1)[np.newaxis, :]
    return np.maximum(sq1 + sq2 - 2 * np.dot(pattern1, pattern2.T), 0)


def _median_width(data):
    """Median-heuristic kernel width for the rows of ``data``.

    Returns sqrt(0.5 * median of the positive pairwise squared distances),
    counting each unordered pair of rows once.  If no two rows differ, the
    Gaussian kernel is constant for every width, so 1.0 is returned rather
    than the NaN produced by the median of an empty array.
    """
    n = data.shape[0]
    dists = _sq_dists(data, data)[np.triu_indices(n, k=1)]
    dists = dists[dists > 0]
    if dists.size == 0:
        return 1.0
    return np.sqrt(0.5 * np.median(dists))


def _center(K):
    """Return H K H with H = I - 11^T / n, computed from row/column means
    (O(n^2)) instead of two dense n x n matrix products (O(n^3))."""
    return (K
            - K.mean(axis=0, keepdims=True)
            - K.mean(axis=1, keepdims=True)
            + K.mean())


def rbf_dot(pattern1, pattern2, deg):
    """Gaussian (RBF) kernel matrix between the rows of two sample sets.

    Parameters
    ----------
    pattern1 : array_like, shape (n1, d) or (n1,)
    pattern2 : array_like, shape (n2, d) or (n2,)
        Samples in rows; 1-D inputs are treated as a single feature column.
    deg : float
        Kernel width (sigma).

    Returns
    -------
    ndarray, shape (n1, n2)
        K[i, j] = exp(-||pattern1[i] - pattern2[j]||^2 / (2 * deg**2)).
    """
    sq = _sq_dists(_as_2d(pattern1), _as_2d(pattern2))
    return np.exp(-sq / 2 / (deg ** 2))


def hsic_gam(X, Y, alph=0.5):
    """HSIC independence test using a Gamma approximation to the null.

    Kernel widths for X and Y are chosen with the median heuristic.

    Parameters
    ----------
    X : array_like, shape (n, dim_x) or (n,)
    Y : array_like, shape (n, dim_y) or (n,)
        Samples in rows, dimensions in columns; 1-D inputs are treated as a
        single column.  X and Y must have the same number of rows.
    alph : float
        Significance level; the threshold is the (1 - alph) quantile of the
        fitted Gamma distribution.

    Returns
    -------
    (testStat, thresh)
        testStat = trace(Kc Lc) / n and the level-alph threshold; the test
        rejects independence when testStat exceeds thresh.

    Raises
    ------
    ValueError
        If X and Y have different numbers of rows, or if fewer than 6
        samples are given (the variance estimate contains the factors
        (n-4)(n-5)/(n-3) and is zero or undefined below that).
    """
    X = _as_2d(X)
    Y = _as_2d(Y)
    n = X.shape[0]
    if Y.shape[0] != n:
        raise ValueError(
            "X and Y must have the same number of samples, got %d and %d"
            % (n, Y.shape[0]))
    if n < 6:
        raise ValueError("hsic_gam needs at least 6 samples, got %d" % n)

    K = rbf_dot(X, X, _median_width(X))
    L = rbf_dot(Y, Y, _median_width(Y))
    Kc = _center(K)
    Lc = _center(L)

    testStat = np.sum(Kc.T * Lc) / n

    # Variance of the statistic under independence (Gretton et al., 2007).
    varHSIC = (Kc * Lc / 6) ** 2
    varHSIC = (np.sum(varHSIC) - np.trace(varHSIC)) / n / (n - 1)
    varHSIC = varHSIC * 72 * (n - 4) * (n - 5) / n / (n - 1) / (n - 2) / (n - 3)

    # Mean under independence uses the off-diagonal means of K and L.
    muX = (np.sum(K) - np.trace(K)) / n / (n - 1)
    muY = (np.sum(L) - np.trace(L)) / n / (n - 1)
    mHSIC = (1 + muX * muY - muX - muY) / n

    # Gamma(shape=al, scale=bet) with mean n*mHSIC and variance n**2*varHSIC.
    al = mHSIC ** 2 / varHSIC
    bet = varHSIC * n / mHSIC
    thresh = gamma.ppf(1 - alph, al, scale=bet)
    return (testStat, thresh)