-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathgmm_fit.py
139 lines (121 loc) · 5.36 KB
/
gmm_fit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import numpy as np
from scipy.stats import multivariate_normal as mvn_pdf
import matplotlib.pyplot as plt
from sklearn.cluster import MiniBatchKMeans
from mixture import GaussianMixture
import pymesh
def compute_gmm(x, k=2, w=None, iter_max=10000, i_tol=1e-9, e_tol=1e-3):
    """Fit a k-component Gaussian mixture to ``x`` with (optionally weighted) EM.

    Initialisation uses MiniBatchKMeans: the cluster centres become the initial
    means, each cluster's sample covariance plus ``i_tol * I`` the initial
    covariance, and the mixing weights start uniform.

    Parameters
    ----------
    x : array_like, shape (n, d)
        Points to fit.
    k : int
        Number of mixture components.
    w : array_like, shape (n,), optional
        Non-negative per-point weights. ``None`` gives every point weight 1
        (plain unweighted EM).
    iter_max : int
        Maximum number of EM iterations.
    i_tol : float
        Ridge added to every covariance diagonal to keep it positive definite.
    e_tol : float
        Stop once the summed squared change of all means falls below this.

    Returns
    -------
    mu : ndarray, shape (k, d)
        Component means.
    sigma : ndarray, shape (k, d, d)
        Component covariances.
    pi : ndarray, shape (k,)
        Mixing weights.
    """
    from scipy.special import logsumexp

    x = np.asarray(x, dtype=float)
    n, d = x.shape
    w = np.ones(n) if w is None else np.asarray(w, dtype=float).reshape(-1)
    ridge = np.identity(d) * i_tol

    km = MiniBatchKMeans(k)
    km.fit(x)
    # Copy so the in-place EM updates never touch the fitted k-means model.
    mu = np.array(km.cluster_centers_, dtype=float)
    sigma = np.empty((k, d, d))
    for i in range(k):
        pts = x[km.labels_ == i]
        # np.cov needs at least two samples; otherwise use the global spread.
        src = pts if pts.shape[0] >= 2 else x
        sigma[i] = np.atleast_2d(np.cov(src, rowvar=False)) + ridge
    pi = np.full(k, 1.0 / k)
    total_w = w.sum()

    for _ in range(iter_max):
        # E-step in the log domain: densities of distant points would
        # otherwise underflow and the normalisation would produce 0/0 = NaN.
        with np.errstate(divide="ignore"):  # log(0) for an emptied component
            log_resp = np.log(pi)[None, :] + np.column_stack(
                [mvn_pdf.logpdf(x, mean=mu[i], cov=sigma[i]) for i in range(k)]
            )
        gamma = np.exp(log_resp - logsumexp(log_resp, axis=1, keepdims=True))
        # Weighted EM: scale each point's responsibilities by its weight.
        gamma *= w[:, None]

        # M-step. Each component is normalised by its OWN effective count;
        # the tiny epsilon keeps a component that lost all its points finite.
        nk = gamma.sum(axis=0) + 10 * np.finfo(float).eps
        mu_prev = mu.copy()
        mu = (gamma.T @ x) / nk[:, None]
        for i in range(k):
            diff = x - mu[i]
            sigma[i] = (gamma[:, i, None] * diff).T @ diff / nk[i] + ridge
        pi = nk / total_w

        if ((mu - mu_prev) ** 2).sum() < e_tol:
            break
    return mu, sigma, pi
# Load the input meshes (paths are relative to the current working directory).
# mesh0 / mesh1: meshes whose faces are fitted (through get_centroids);
# mesh2 / mesh3: meshes whose vertices are fitted directly;
# mesh4: vertices used only to score the fitted mixtures.
mesh0, mesh1, mesh2, mesh3, mesh4 = [
    pymesh.load_mesh(path)
    for path in (
        "bunny/bun_zipper_1000_1.ply",
        "bunny/bun_zipper_992_1.ply",
        "bunny/bun_zipper_pts_1000_1.ply",
        "bunny/bun_zipper_pds_1000_1.ply",
        # alternative for mesh3: "bunny/bun_zipper_res4_pds.ply"
        "bunny/bun_zipper_50k.ply",
    )
]
def get_tri_covar(tris):
    """Return the covariance of a uniform distribution over each triangle.

    For a triangle with vertices A, B, C and centroid M = (A + B + C) / 3 the
    covariance is (A A^T + B B^T + C C^T - 3 M M^T) / 12.

    Parameters
    ----------
    tris : array_like, shape (F, 3, D)
        Vertex coordinates; ``tris[f, v]`` is vertex ``v`` of face ``f``.

    Returns
    -------
    ndarray, shape (F, D, D)
        One covariance matrix per face. An empty ``(0, 3, D)`` input yields
        ``(0, D, D)``; a bare empty sequence yields ``(0, 3, 3)``.
    """
    tris = np.asarray(tris, dtype=float)
    if tris.size == 0 and tris.ndim != 3:
        # A bare empty sequence carries no dimension; assume 3-D coordinates.
        return np.zeros((0, 3, 3))
    centroids = tris.mean(axis=1)                                   # (F, D)
    second_moment = np.einsum("fvi,fvj->fij", tris, tris)           # sum_v v v^T
    centre_outer = np.einsum("fi,fj->fij", centroids, centroids)    # M M^T
    return (second_moment - 3.0 * centre_outer) / 12.0
def get_centroids(mesh):
    """Return per-face centroids, areas and corner coordinates of a mesh.

    ``mesh`` must expose a ``vertices`` array and an integer ``faces`` array
    with three vertex indices per face. The area computation uses a cross
    product, so vertices are assumed to be 3-D (TODO confirm for other data).

    Returns
    -------
    centroids : ndarray, shape (F, D)
        Mean of each face's three corners.
    areas : ndarray, shape (F,)
        Half the norm of the cross product of the two edges leaving corner 0.
    face_vert : ndarray, shape (F, 3, D)
        ``face_vert[f, v]`` is the coordinate of corner ``v`` of face ``f``.
    """
    n_faces = mesh.faces.shape[0]
    corners = np.take(mesh.vertices, mesh.faces.ravel(), axis=0).reshape((n_faces, 3, -1))
    centroids = corners.mean(axis=1)
    edge_ab = corners[:, 1, :] - corners[:, 0, :]
    edge_ac = corners[:, 2, :] - corners[:, 0, :]
    areas = 0.5 * np.linalg.norm(np.cross(edge_ab, edge_ac), axis=1)
    return centroids, areas, corners
# Face centroids (fit points), face areas and per-face corners of both meshes.
coma, aa, fv1 = get_centroids(mesh0)
com, a, fv2 = get_centroids(mesh1)

# Rescale areas so the smallest face of each mesh has weight exactly 1.
aa = aa / aa.min()
a = a / a.min()

# Per-face covariance of a uniform distribution over each triangle.
data_covar1, data_covar2 = get_tri_covar(fv1), get_tri_covar(fv2)
def _fit_mixture(n_components, init, points, covars=None, areas=None):
    """Fit a GaussianMixture with the settings shared by every experiment.

    When ``covars`` / ``areas`` are given they are installed on the model for
    the fit only (covariances first, then areas) and cleared afterwards, in
    that same order.
    """
    gm = GaussianMixture(n_components, init_params=init, max_iter=25, tol=1e-12)
    if covars is not None:
        gm.set_covars(covars)
    if areas is not None:
        gm.set_areas(areas)
    gm.fit(points)
    if covars is not None:
        gm.set_covars(None)
    if areas is not None:
        gm.set_areas(None)
    return gm


# Each log line: n_components, init method, model tag, score on mesh4, n_iter_.
with open('bunny_fit_monday_subsamples_25.log', 'w') as fout:
    for km in [6, 12, 25, 50, 100, 200, 400]:
        for init in ['random', 'kmeans']:
            for _ in range(10):
                models = [
                    _fit_mixture(km, init, coma, data_covar1, aa),  # tag '0'
                    _fit_mixture(km, init, com, data_covar2, a),    # tag '1'
                    _fit_mixture(km, init, mesh3.vertices),         # tag '2'
                    _fit_mixture(km, init, mesh2.vertices),         # tag '3'
                ]
                # Score every fitted model on the same evaluation points.
                scores = [gm.score(mesh4.vertices) for gm in models]
                for tag, (gm, score) in enumerate(zip(models, scores)):
                    fout.write("{},{},{},{},{}\n".format(km, init, str(tag), score, gm.n_iter_))

# Debug visualisation of face centroids (sized by area) against mesh2's
# vertices; flip to True to enable.
if False:
    import mpl_toolkits.mplot3d  # noqa: F401  (registers the '3d' projection)
    verts = mesh2.vertices
    ax = plt.figure().add_subplot(projection='3d')
    ax.scatter(com[:, 0], com[:, 1], com[:, 2], s=a)
    ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], s=20)
    plt.show()