-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathlds_messages_interface.py
111 lines (88 loc) · 4.07 KB
/
lds_messages_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from __future__ import division
import numpy as np
from numpy.lib.stride_tricks import as_strided
from functools import wraps, partial
################################
# distribution-form wrappers #
################################
from pylds.lds_messages import \
kalman_filter as _kalman_filter, \
rts_smoother as _rts_smoother, \
filter_and_sample as _filter_and_sample, \
kalman_filter_diagonal as _kalman_filter_diagonal, \
filter_and_sample_diagonal as _filter_and_sample_diagonal, \
filter_and_sample_randomwalk as _filter_and_sample_randomwalk, \
E_step as _E_step
class _ArgShapeError(ValueError, AssertionError):
    """Raised by ``_ensure_ndim`` when an argument has the wrong rank or length.

    It derives from ``AssertionError`` as well as ``ValueError`` because earlier
    versions signalled these failures with ``assert``; existing handlers that
    catch ``AssertionError`` therefore keep working.
    """


def _ensure_ndim(X, T, ndim):
    """Return ``X`` as a float64, C-contiguous array of rank ``ndim``.

    ``X`` may either already have rank ``ndim`` with a leading axis of length
    ``T``, or have rank ``ndim - 1``. In the second case it is repeated along a
    new leading axis of length ``T`` through a zero-stride view, so no copy of
    the repeated data is made.

    Parameters
    ----------
    X : array_like
        Input converted with ``np.require(..., dtype=np.float64, requirements='C')``.
    T : int
        Required length of the leading axis of the result.
    ndim : int
        Rank of the result.

    Returns
    -------
    ndarray
        ``X`` itself (after conversion) when it already has rank ``ndim``;
        otherwise a view of shape ``(T,) + X.shape``.

    Raises
    ------
    _ArgShapeError
        (a ``ValueError`` and an ``AssertionError``) if ``X`` has a rank other
        than ``ndim - 1`` or ``ndim``, or if its leading axis is not ``T`` long.
        Explicit raises are used instead of ``assert`` so the checks still run
        under ``python -O``.
    """
    X = np.require(X, dtype=np.float64, requirements='C')
    if not (ndim - 1 <= X.ndim <= ndim):
        raise _ArgShapeError(
            "expected an array with {0} or {1} dimensions, got {2}"
            .format(ndim - 1, ndim, X.ndim))
    if X.ndim == ndim:
        if X.shape[0] != T:
            raise _ArgShapeError(
                "expected a leading axis of length {0}, got shape {1}"
                .format(T, X.shape))
        return X
    # Stride 0 on the new leading axis: every out[t] aliases the same memory
    # as the (possibly converted) X, so a write through the view is visible
    # at every index t and in X itself.
    return as_strided(X, shape=(T,) + X.shape, strides=(0,) + X.strides)
def _argcheck(mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data):
    """Normalize the arguments shared by the dense distribution-form routines.

    ``data`` and ``inputs`` are converted to float64, C-contiguous arrays.
    ``A``, ``B``, ``sigma_states``, ``C``, ``D`` and ``sigma_obs`` go through
    ``_ensure_ndim`` with ``ndim=3``. Each may therefore be given with a leading
    axis of length ``T`` or without it, in which case it is repeated along a new
    leading axis of length ``T``. ``T`` is the length of the first axis of
    ``data``. ``mu_init`` and ``sigma_init`` are passed through unchanged.

    Returns
    -------
    tuple
        The ten arguments in their original order, ready to be passed
        positionally.
    """
    # Convert data before reading its shape so that array-likes such as
    # nested lists are accepted, not only ndarrays.
    data = np.require(data, dtype=np.float64, requirements='C')
    T = data.shape[0]
    ensure3 = partial(_ensure_ndim, T=T, ndim=3)
    A, B, sigma_states, C, D, sigma_obs = \
        [ensure3(X) for X in (A, B, sigma_states, C, D, sigma_obs)]
    # Make inputs float64 and C-ordered (np.require does not change its rank).
    inputs = np.require(inputs, dtype=np.float64, requirements='C')
    return mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data
def _argcheck_diag_sigma_obs(mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data):
    """Normalize the arguments of the ``*_diagonal`` distribution-form routines.

    ``data`` and ``inputs`` are converted to float64, C-contiguous arrays.
    ``A``, ``B``, ``sigma_states``, ``C`` and ``D`` go through ``_ensure_ndim``
    with ``ndim=3``; ``sigma_obs`` goes through it with ``ndim=2``. Each may
    therefore be given with a leading axis of length ``T`` or without it, in
    which case it is repeated along a new leading axis of length ``T``. ``T`` is
    the length of the first axis of ``data``. ``mu_init`` and ``sigma_init`` are
    passed through unchanged.

    Returns
    -------
    tuple
        The ten arguments in their original order, ready to be passed
        positionally.
    """
    # Convert data before reading its shape so that array-likes such as
    # nested lists are accepted, not only ndarrays.
    data = np.require(data, dtype=np.float64, requirements='C')
    T = data.shape[0]
    ensure3 = partial(_ensure_ndim, T=T, ndim=3)
    A, B, sigma_states, C, D = \
        [ensure3(X) for X in (A, B, sigma_states, C, D)]
    # sigma_obs is one rank lower here than in the dense routines.
    sigma_obs = _ensure_ndim(sigma_obs, T=T, ndim=2)
    inputs = np.require(inputs, dtype=np.float64, requirements='C')
    return mu_init, sigma_init, A, B, sigma_states, C, D, sigma_obs, inputs, data
def _argcheck_randomwalk(mu_init, sigma_init, sigmasq_states, sigmasq_obs, data):
    """Normalize the arguments of ``filter_and_sample_randomwalk``.

    ``data`` is converted to a float64, C-contiguous array.
    ``sigmasq_states`` and ``sigmasq_obs`` go through ``_ensure_ndim`` with
    ``ndim=2``. Each may therefore be given with a leading axis of length ``T``
    or without it, in which case it is repeated along a new leading axis of
    length ``T``. ``T`` is the length of the first axis of ``data``.
    ``mu_init`` and ``sigma_init`` are passed through unchanged.

    Returns
    -------
    tuple
        The five arguments in their original order, ready to be passed
        positionally.
    """
    # Convert data before reading its shape so that array-likes such as
    # nested lists are accepted, not only ndarrays.
    data = np.require(data, dtype=np.float64, requirements='C')
    T = data.shape[0]
    ensure2 = partial(_ensure_ndim, T=T, ndim=2)
    sigmasq_states, sigmasq_obs = \
        [ensure2(X) for X in (sigmasq_states, sigmasq_obs)]
    return mu_init, sigma_init, sigmasq_states, sigmasq_obs, data
def _wrap(func, check):
    """Build a front end for ``func`` that first routes its arguments through ``check``.

    The returned callable accepts any positional/keyword arguments, hands them
    to ``check``, and calls ``func`` positionally with whatever sequence
    ``check`` returns. ``functools.wraps`` copies ``func``'s name, docstring and
    related metadata onto the front end.
    """
    def wrapped(*args, **kwargs):
        normalized = check(*args, **kwargs)
        return func(*normalized)

    return wraps(func)(wrapped)
# Public distribution-form entry points. Each one normalizes its arguments with
# the matching _argcheck* helper, then calls the function of the same name
# imported from pylds.lds_messages.
kalman_filter = _wrap(_kalman_filter, check=_argcheck)
rts_smoother = _wrap(_rts_smoother, check=_argcheck)
filter_and_sample = _wrap(_filter_and_sample, check=_argcheck)
E_step = _wrap(_E_step, check=_argcheck)

# *_diagonal variants: sigma_obs is normalized with ndim=2 instead of 3.
kalman_filter_diagonal = _wrap(_kalman_filter_diagonal,
                               check=_argcheck_diag_sigma_obs)
filter_and_sample_diagonal = _wrap(_filter_and_sample_diagonal,
                                   check=_argcheck_diag_sigma_obs)

# Random-walk variant: takes only mu_init, sigma_init, sigmasq_states,
# sigmasq_obs and data.
filter_and_sample_randomwalk = _wrap(_filter_and_sample_randomwalk,
                                     check=_argcheck_randomwalk)
###############################
# information-form wrappers #
###############################
from pylds.lds_info_messages import \
kalman_info_filter as _kalman_info_filter, \
info_E_step as _info_E_step, \
info_sample as _info_sample
def _info_argcheck(J_init, h_init, log_Z_init,
                   J_pair_11, J_pair_21, J_pair_22, h_pair_1, h_pair_2, log_Z_pair,
                   J_node, h_node, log_Z_node):
    """Normalize the arguments of the information-form routines.

    ``T`` is the length of the first axis of ``h_node``, which is converted to
    a float64, C-contiguous array. Every other array argument goes through
    ``_ensure_ndim``, so it may be given with the required leading axis or
    without it, in which case it is repeated along a new leading axis:

    * node terms: ``J_node`` (ndim 3) and ``log_Z_node`` (ndim 1), leading
      length ``T``;
    * pair terms: ``J_pair_11``, ``J_pair_21``, ``J_pair_22`` (ndim 3),
      ``h_pair_1``, ``h_pair_2`` (ndim 2) and ``log_Z_pair`` (ndim 1),
      leading length ``T - 1``.

    ``J_init`` and ``h_init`` are passed through unchanged. ``log_Z_init`` is
    passed through after an ``assert`` that ``np.isscalar`` accepts it.

    Returns
    -------
    tuple
        The twelve arguments in their original order, ready to be passed
        positionally.
    """
    # Convert h_node before reading its shape so that array-likes such as
    # nested lists are accepted, not only ndarrays.
    h_node = np.require(h_node, dtype=np.float64, requirements='C')
    T = h_node.shape[0]
    assert np.isscalar(log_Z_init)

    J_node = _ensure_ndim(J_node, T=T, ndim=3)
    log_Z_node = _ensure_ndim(log_Z_node, T=T, ndim=1)

    # Pair terms have one fewer entry along the leading axis than node terms.
    ensure_pair3 = partial(_ensure_ndim, T=T - 1, ndim=3)
    ensure_pair2 = partial(_ensure_ndim, T=T - 1, ndim=2)
    J_pair_11, J_pair_21, J_pair_22 = \
        [ensure_pair3(X) for X in (J_pair_11, J_pair_21, J_pair_22)]
    h_pair_1, h_pair_2 = [ensure_pair2(X) for X in (h_pair_1, h_pair_2)]
    log_Z_pair = _ensure_ndim(log_Z_pair, T=T - 1, ndim=1)

    return J_init, h_init, log_Z_init, \
        J_pair_11, J_pair_21, J_pair_22, h_pair_1, h_pair_2, log_Z_pair, \
        J_node, h_node, log_Z_node
# Public information-form entry points. Each one normalizes its arguments with
# _info_argcheck, then calls the function of the same name imported from
# pylds.lds_info_messages.
kalman_info_filter = _wrap(_kalman_info_filter, check=_info_argcheck)
info_E_step = _wrap(_info_E_step, check=_info_argcheck)
info_sample = _wrap(_info_sample, check=_info_argcheck)