-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathlda_gibbs_example.m
176 lines (117 loc) · 3.87 KB
/
lda_gibbs_example.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
% ----------------------------------
% A simple illustration of LDA
% (1) generate W according to the LDA's generative model
% (2) Perform Gibbs Sampling
% (3) part of teaching for USTC computation statistics
%
% By Richard Xu
%
% 2017/04/10
%
% ----------------------------------
function lda_gibbs_example
% Simple illustration of LDA:
%   (1) generate W according to LDA's generative model
%   (2) perform (uncollapsed) Gibbs sampling over theta, beta and Z
clc;
clear;
%check_dirrand

% K is number of topics
K = 5;
% alpha = [2 ... 2] with K elements (Dirichlet prior on topic proportions)
alphas = ones(1,K)*2;
% V is the total number of words in a corpus (vocabulary size)
V = 20;
% eta = [2 ... 2] with V elements (Dirichlet prior on word distributions)
etas = ones(1,V)*2;
% D is the number of documents
D = 6;
% N is number of words in a document
N = 10;

% ------------------------------------------------
% Generate W, the observed words
% ------------------------------------------------
[betas, thetas, Z, W] = generate_LDA(K,V,D,N, alphas, etas);

% ------------------------------------------------
% Re-draw a fresh random initialisation of the latent variables
% (betas, thetas, Z); W from the first draw is kept as the data.
% ------------------------------------------------
[betas, thetas, Z, temp] = generate_LDA(K,V,D,N, alphas, etas);

for g = 1:100

    % ---- sample theta_d | Z, alpha  ~  Dir(alpha + topic counts in doc d)
    for d = 1:D
        I = zeros(1,K);
        for k = 1:K
            I(k) = sum(Z(d,:) == k);
        end
        thetas(d,:) = dirrnd(alphas + I);
    end

    % ---- sample beta_k | W, Z, eta  ~  Dir(eta + per-topic word counts)
    % BUG FIX: the original counted every occurrence of word v in W
    % regardless of its topic assignment, so all K topics received
    % identical counts and betas never conditioned on Z.  The count for
    % topic k must be restricted to tokens with Z(d,n) == k.
    for k = 1:K
        I = zeros(1,V);
        for v = 1:V
            I(v) = sum(W(:) == v & Z(:) == k);
        end
        betas(k,:) = dirrnd(etas + I);
    end

    % ---- sample Z_{d,n} | theta, beta, W
    for d = 1:D
        for n = 1:N
            probs = zeros(1,K);
            for k = 1:K
                % Pr(Z_{d,n} = k) is proportional to theta_{d,k} * beta_{k, w_{d,n}}
                probs(k) = thetas(d,k) * betas(k, W(d,n));
            end
            probs = probs / sum(probs);
            Z(d,n) = find(mnrnd(1, probs));
        end
    end

    fprintf('at %d iterations\n',g);
    display(thetas);
    display(betas);
    display(Z);
end

end
% ------------------------------------------------------------------
% Generate LDA according to the generative model
% ------------------------------------------------------------------
function [betas, thetas, Z, W] = generate_LDA(K,V,D,N,alphas, etas)
% Draw a synthetic corpus from the LDA generative process.
%   betas  : K x V, per-topic word distributions   ~ Dir(etas)
%   thetas : D x K, per-document topic proportions ~ Dir(alphas)
%   Z      : D x N, topic index of every word slot
%   W      : D x N, observed word index of every slot

% one word distribution per topic
betas = zeros(K,V);
for topic = 1:K
    betas(topic,:) = dirrnd(etas);
end

Z      = zeros(D,N);
W      = zeros(D,N);
thetas = zeros(D,K);

for d = 1:D
    % topic proportions for this document
    thetas(d,:) = dirrnd(alphas);
    % draw all N one-hot topic picks for document d in a single call
    picks = mnrnd(1, thetas(d,:), N);
    for n = 1:N
        Z(d,n) = find(picks(n,:), 1);
        % emit a word from the chosen topic's distribution
        W(d,n) = find(mnrnd(1, betas(Z(d,n),:)), 1);
    end
end
end
% ------------------------------------------------------------------
% check to see if dirichlet distribution is generated properly
% ------------------------------------------------------------------
function check_dirrand
% Sanity check that dirrnd draws proper Dirichlet samples: the empirical
% mean and variance of many draws should match the theoretical values.
alphas = [1 2 3];
n_samples = 10000;
% preallocate instead of growing g on every loop iteration
g = zeros(n_samples, length(alphas));
for i = 1:n_samples
    g(i,:) = dirrnd(alphas);
end
% mean of Dirichlet distribution should be alpha/sum(alpha)
display(mean(g));
% This is the empirical variance.
display(var(g));
% This is the theoretical variance:
%   Var(x_i) = alpha_i * (a0 - alpha_i) / (a0^2 * (a0 + 1)),  a0 = sum(alpha)
% (works for any length of alphas; the original repmat hard-coded 3, and
% the variable was misleadingly named emp_var).
a0 = sum(alphas);
theo_var = alphas .* (a0 - alphas) / a0^2 / (a0 + 1);
display(theo_var);
end
% ------------------------------------------------------------------
% Since MATLAB does not provide any method for sampling dirichlet
% distribution, we need to sample it using Gamma distribution.
% ------------------------------------------------------------------
function g = dirrnd(alphas)
% Draw one sample from Dirichlet(alphas) by normalising independent
% Gamma(alpha_i, 1) draws — MATLAB has no built-in Dirichlet sampler.
raw = gamrnd(alphas, 1);
g = raw / sum(raw);
end