-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathHMM3.java
125 lines (109 loc) · 3.72 KB
/
HMM3.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import java.util.Scanner;
import java.lang.Math;
import java.util.Locale;
/**
 * Viterbi decoder for a discrete hidden Markov model read from standard input.
 *
 * <p>Variable names follow the notation of the Stamp HMM tutorial. The whitespace-separated
 * input is, in order:
 * <ul>
 *   <li>A as "N N" followed by N*N values in row-major order;</li>
 *   <li>B as "N M" followed by N*M values;</li>
 *   <li>pi as "1 N" followed by N values;</li>
 *   <li>the observation count T followed by T symbol indices.</li>
 * </ul>
 * The most likely hidden state sequence is printed on one line as space-separated state
 * indices, each followed by a single space.
 */
public class HMM3 {
    // length of observation sequence (T)
    public int _t;
    // number of states in the model (N)
    public int _n;
    // number of observation symbols (M)
    public int _m;
    // state transition probabilities, _a[from][to] (N x N)
    public double[][] _a;
    // observation probability matrix, _b[state][symbol] (N x M)
    public double[][] _b;
    // initial state distribution (length N)
    public double[] _pi;
    // observation sequence of symbol indices (length T)
    public int[] _o;

    public HMM3() {}

    /**
     * Natural logarithm that maps a zero probability to negative infinity.
     *
     * <p>A finite sentinel such as -1 would be indistinguishable from ln(1/e) and would corrupt
     * log-space comparisons. Negative infinity compares below every real log-probability.
     *
     * @param foo a probability (expected to be non-negative)
     * @return ln(foo), or {@code Double.NEGATIVE_INFINITY} when foo is 0
     */
    public double _log(double foo) {
        if (foo == 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        return Math.log(foo);
    }

    /**
     * Reads the model and the observation sequence from stdin and prints the Viterbi path.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Force '.' as the decimal separator for Scanner.nextDouble.
        Locale.setDefault(Locale.ENGLISH);
        Scanner sc = new Scanner(System.in);
        HMM3 h = new HMM3();
        h.readInput(sc);
        System.out.println(formatPath(h.viterbi()));
    }

    /**
     * Parses A, B, pi and the observation sequence into this instance's fields.
     * Dimension mismatches are reported on stdout and parsing continues.
     */
    private void readInput(Scanner sc) {
        _n = sc.nextInt();
        if (_n != sc.nextInt()) {
            System.out.println("*** var error: N mismatch ***");
        }
        _a = readMatrix(sc, _n, _n);

        if (_n != sc.nextInt()) {
            System.out.println("*** var error: N mismatch ***");
        }
        _m = sc.nextInt();
        _b = readMatrix(sc, _n, _m);

        if (1 != sc.nextInt()) {
            System.out.println("*** var error: pi mismatch ***");
        }
        if (_n != sc.nextInt()) {
            System.out.println("*** var error: N mismatch ***");
        }
        _pi = readMatrix(sc, 1, _n)[0];

        _o = new int[sc.nextInt()];
        for (int i = 0; i < _o.length; i++) {
            _o[i] = sc.nextInt();
        }
        _t = _o.length;
    }

    /** Reads rows*cols doubles in row-major order. */
    private static double[][] readMatrix(Scanner sc, int rows, int cols) {
        double[][] m = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                m[i][j] = sc.nextDouble();
            }
        }
        return m;
    }

    /** Returns the element-wise {@link #_log} of a matrix. */
    private double[][] logMatrix(double[][] p) {
        double[][] out = new double[p.length][];
        for (int i = 0; i < p.length; i++) {
            out[i] = new double[p[i].length];
            for (int j = 0; j < p[i].length; j++) {
                out[i][j] = _log(p[i][j]);
            }
        }
        return out;
    }

    /**
     * Runs the Viterbi algorithm over {@code _o} using {@code _a}, {@code _b} and {@code _pi}.
     *
     * <p>Works in log space. In double precision, a product of a few hundred probabilities
     * underflows to 0.0. Once that happens every delta is 0, no comparison succeeds, and the
     * backpointers degenerate to state 0. Sums of logarithms stay representable.
     *
     * @return the most likely state index for each time step; empty when {@code _t} is 0
     */
    private int[] viterbi() {
        int[] path = new int[_t];
        if (_t == 0) {
            return path;
        }
        double[][] logA = logMatrix(_a);
        double[][] logB = logMatrix(_b);
        double[] logPi = logMatrix(new double[][] {_pi})[0];

        // delta[t][i]: best log-probability of any state path ending in state i at time t
        double[][] delta = new double[_t][_n];
        // back[t][i]: predecessor state of i on that best path
        int[][] back = new int[_t][_n];

        for (int i = 0; i < _n; i++) {
            delta[0][i] = logB[i][_o[0]] + logPi[i];
        }
        for (int t = 1; t < _t; t++) {
            for (int i = 0; i < _n; i++) {
                double emit = logB[i][_o[t]];
                double best = Double.NEGATIVE_INFINITY;
                int bestPrev = 0; // stays 0 when every predecessor is impossible
                for (int j = 0; j < _n; j++) {
                    double cand = logA[j][i] + delta[t - 1][j] + emit;
                    if (cand > best) {
                        best = cand;
                        bestPrev = j;
                    }
                }
                delta[t][i] = best;
                back[t][i] = bestPrev;
            }
        }

        // Pick the most likely final state, then follow the backpointers.
        int last = _t - 1;
        double bestFinal = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < _n; j++) {
            if (delta[last][j] > bestFinal) {
                bestFinal = delta[last][j];
                path[last] = j;
            }
        }
        for (int t = last - 1; t >= 0; t--) {
            path[t] = back[t + 1][path[t + 1]];
        }
        return path;
    }

    /** Formats states as "s0 s1 ... sk " (each index followed by one space). */
    private static String formatPath(int[] path) {
        StringBuilder sb = new StringBuilder();
        for (int state : path) {
            sb.append(state).append(' ');
        }
        return sb.toString();
    }
}