-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
77 lines (67 loc) · 2.19 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import ctypes
import cupy
# Load the compiled GMM library. The path contains a slash, so the dynamic
# loader resolves it relative to the current working directory: the script
# has to be launched from the directory that contains build/.
gmm = ctypes.cdll.LoadLibrary('./build/libgmm.so')

# Declare gmmFit's C argument types so ctypes converts/checks every call.
# Slot layout, as used by the gmmFit calls in this script:
#   4 x double* -> data, weights, means, covariances buffers (CuPy device
#                  pointers cast with ctypes.cast)
#   3 x int     -> data.shape[0], data.shape[1], component count
#                  (the component count matches the length of `weights`)
#   double      -> 1e-6 at the call sites (presumably a convergence tolerance)
#   int         -> 300 at the call sites (presumably an iteration cap)
# NOTE(review): restype is left at the ctypes default (c_int); confirm it
# matches the C declaration of gmmFit.
gmm.gmmFit.argtypes = (
    [ctypes.POINTER(ctypes.c_double)] * 4
    + [ctypes.c_int] * 3
    + [ctypes.c_double, ctypes.c_int]
)
# # Generate data from two three-dimensional Gaussian distributions
# dist1_mean = cupy.array([-1, 1, -1], dtype=cupy.float64)
# dist1_covar = cupy.array(
# [[ 3, -2, 0],
# [-2, 2, 0],
# [ 0, 0, 2]]
# , dtype=cupy.float64)
# dist1_data = cupy.random.multivariate_normal(dist1_mean, dist1_covar, size=7000)
# dist2_mean = cupy.array([2, -1.5, 3], dtype=cupy.float64)
# dist2_covar = cupy.array(
# [[ 3, 1, -5],
# [ 1, 1, -1],
# [-5, -1, 10]]
# , dtype=cupy.float64)
# dist2_data = cupy.random.multivariate_normal(dist2_mean, dist2_covar, size=3000)
# # Mixture weights are 7:3 (7000 vs. 3000 samples)
# data = cupy.concatenate([dist1_data, dist2_data])
# cupy.random.shuffle(data)
# weights = cupy.empty(2, dtype=cupy.float64)
# means = cupy.empty((2, 3), dtype=cupy.float64)
# covariances = cupy.empty((2, 3, 3), dtype=cupy.float64)
# gmm.gmmFit(
# ctypes.cast(data.data.ptr, ctypes.POINTER(ctypes.c_double)),
# ctypes.cast(weights.data.ptr, ctypes.POINTER(ctypes.c_double)),
# ctypes.cast(means.data.ptr, ctypes.POINTER(ctypes.c_double)),
# ctypes.cast(covariances.data.ptr, ctypes.POINTER(ctypes.c_double)),
# data.shape[0],
# data.shape[1],
# 2,
# 1e-6,
# 300
# )
# print('weights:\n', weights)
# print('means:\n', means)
# print('covariances:\n', covariances)
# Fit a Gaussian mixture to the training images via the native gmmFit.
# NOTE(review): the division by 255 suggests raw pixel values in [0, 255]
# (MNIST-style images) — confirm against the data-preparation step.
N_COMPONENTS = 10   # number of mixture components
TOLERANCE = 1e-6    # passed in gmmFit's double slot (presumably convergence tolerance)
MAX_ITER = 300      # passed in gmmFit's last int slot (presumably iteration cap)

data = cupy.load('data/train-images.npy') / 255
cupy.random.shuffle(data)  # in-place shuffle along the first (sample) axis
# gmmFit only receives a raw double* plus explicit sizes, so the buffer must
# be a dense, row-major float64 matrix of shape (n_samples, n_features).
# Flatten any per-sample image dimensions (e.g. 28x28 -> 784) and force
# float64 / C-contiguous layout instead of relying on implicit dtype
# promotion; both steps are no-ops when the array already qualifies.
data = cupy.ascontiguousarray(data.reshape(data.shape[0], -1), dtype=cupy.float64)
print(data.shape)
n_samples, n_features = data.shape

# Output buffers written by gmmFit. Their shapes are derived from exactly the
# sizes passed to the C function, so they can never disagree with what the
# native code indexes (previously 784 was hard-coded independently).
weights = cupy.empty(N_COMPONENTS, dtype=cupy.float64)
means = cupy.empty((N_COMPONENTS, n_features), dtype=cupy.float64)
covariances = cupy.empty((N_COMPONENTS, n_features, n_features), dtype=cupy.float64)
gmm.gmmFit(
    ctypes.cast(data.data.ptr, ctypes.POINTER(ctypes.c_double)),
    ctypes.cast(weights.data.ptr, ctypes.POINTER(ctypes.c_double)),
    ctypes.cast(means.data.ptr, ctypes.POINTER(ctypes.c_double)),
    ctypes.cast(covariances.data.ptr, ctypes.POINTER(ctypes.c_double)),
    n_samples,
    n_features,
    N_COMPONENTS,
    TOLERANCE,
    MAX_ITER,
)