-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathksvd.cpp
228 lines (191 loc) · 6.02 KB
/
ksvd.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
// Copyright (C) 2019 Piotr (Peter) Beben <[email protected]>
// See LICENSE included.
#define EIGEN_NO_MALLOC
//#define DEBUG_KSVD
#include "ksvd.h"
#include "constants.h"
#include <Eigen/Dense>
#include <functional>
#include <math.h>
#include <vector>
#include <omp.h>
#include <iostream>
using std::cout;
using std::endl;
using std::vector;
using Eigen::MatrixXf;
using Eigen::MatrixXi;
using Eigen::Matrix;
using Eigen::VectorXf;
using Eigen::VectorXi;
using Eigen::Index;
using Eigen::Map;
using Eigen::Dynamic;
using Eigen::Aligned16;
//-----------------------------------------------------------
/**
The K-SVD dictionary learning algorithm.
Provided an integer L and n x k matrix Y consisting of our
k size n signal vectors, we look for an n x m dictionary
matrix D of m atoms (unit column vectors of size n) and a
sparse m x k matrix X of column code vectors that encode the
signals in Y as closely as possible using no more than L
atoms in D. In detail, we solve the minimization problem
min_{X, D} ||Y - DX||_F
subject to
||X||_0 <= L
where ||.||_F is the matrix Frobenius norm, and ||.||_0 is the
vector L_0 norm (the number of non-zero entries in a vector).
@param[in] useOpenMP: Whether to parallelize using OpenMP.
@param[in] Y: n x k signal matrix.
@param[in] latm: Sparsity constraint L.
@param[in] maxIters: Max. number of K-SVD iterations.
@param[in] maxError: Per-signal residual norm ||y - D*x|| below which
(for every signal) the iteration loop stops early (< 0.0 for none).
@param[in] svPowIters: Number of power iterations to approximate
first singular vectors.
@param[in] sparseFunct: Sparse approximation functor.
@param[in/out] D in: first approximation n x m 'dictionary' matrix.
D out: learnt dictionary adapted to the signals Y.
@param[out] X: m x k 'code' matrix.
*/
void ksvd(
bool useOpenMP,
const MatrixXf& Y,
Index latm,
int maxIters,
float maxError,
int svPowIters,
const std::function<void(
const VectorXf&,
const MatrixXf&,
Index,
VectorXf&,
VectorXf&)> sparseFunct,
MatrixXf& D,
MatrixXf& X
)
{
Index ndim = D.rows();
Index natm = D.cols();
Index nsig = Y.cols();
assert(ndim == Y.rows());
assert(natm == X.rows());
assert(nsig == X.cols());
assert(maxIters >= 1);
assert(svPowIters >= 1);
assert(latm <= ndim && latm <= natm);
bool stopAtMaxError = (maxError >= 0.0f);
float maxErrorSq = 0.0f;
bool smallError;
float* errsig = new float[nsig];
for(Index i=0; i < nsig; ++i) { errsig[i] = float_infinity; }
if( stopAtMaxError ) maxErrorSq = maxError*maxError;
float *dwork = new float[nsig*ndim];
MatrixXf Z(ndim,nsig);
Map<MatrixXf, ALIGNEDX> Zblk(nullptr, 0, 0);
VectorXf ZTA(nsig);
VectorXf ZZTA(ndim);
VectorXf A(ndim);
VectorXf B(nsig);
MatrixXi iatmUsed(latm,nsig);
VectorXi natmUsed(nsig);
#pragma omp parallel if(useOpenMP) default(shared) firstprivate(sparseFunct)
{
VectorXf Ysig(ndim);
VectorXf Xsig(natm);
VectorXf R(ndim);
for(int iter = 1; iter <= maxIters; ++iter){
//***
#pragma omp single
{
#ifdef __DEBUG_KSVD
if(iter == 1) cout << "\nAverge error (coordinate difference):\n" ;
cout << (Y-(D*X)).cwiseAbs().sum()/(ndim*nsig) << endl;
#endif
}
// Fix dictionary D and optimize code matrix X.
#pragma omp for schedule(dynamic)
for(Index isig = 0; isig < nsig; ++isig){
Ysig = Y.col(isig);
sparseFunct(Ysig, D, latm, Xsig, R);
float error = R.dot(R);
if( error <= errsig[isig] ) {
X.col(isig) = Xsig;
errsig[isig] = error;
}
}
if( stopAtMaxError ){
// Stop if Y and D*X are similar within tolerance.
#pragma omp single
{
smallError = true;
for(Index isig = 0; isig < nsig; ++isig){
if( errsig[isig] > maxErrorSq ){
smallError = false;
break;
}
}
}//single
if( smallError ) break;
}
// ---
// Now optimize dictionary D for the current code vector X
// one column (atom) at a time.
//#pragma omp for schedule(dynamic)
#pragma omp single
{
// queue up atoms used by each signal
for(Index isig = 0; isig < nsig; ++isig){
int ic = 0;
for(Index iatm = 0; iatm < natm; ++iatm){
if( X(iatm,isig) == 0.0f ) continue;
iatmUsed(ic,isig) = int(iatm);
++ic;
if(ic >= latm) break;
}
natmUsed(isig) = ic;
}
for(Index iatm = 0; iatm < natm; ++iatm){
A = D.col(iatm); // Original atom is our initial approx.
// Compute the matrix Z of residuals for current atom.
Index nsigUsing = 0;
for(Index isig = 0; isig < nsig; ++isig){
if( X(iatm,isig) == 0.0f ) continue;
Z.col(nsigUsing) = Y.col(isig);
for(Index i = 0; i < natmUsed(isig); ++i){
Index jatm = iatmUsed(i,isig);
if( jatm == iatm ) continue;
Z.col(nsigUsing) -= X(jatm,isig)*D.col(jatm);
}
++nsigUsing;
}
if( nsigUsing == 0 ) continue;
// Map to workspace
new (&Zblk) Map<MatrixXf, ALIGNEDX>(dwork,ndim,nsigUsing);
Zblk = Z.block(0,0,ndim,nsigUsing);
// We only need the first singular vector, do a power
// iteration to approximate it. This is our new improved atom.
for(int i=1; i <= svPowIters; ++i){
ZTA.segment(0,nsigUsing).noalias() = Zblk.transpose()*A;
ZZTA.noalias() = Zblk*ZTA.segment(0,nsigUsing);
A = ZZTA.normalized(); // Optimized atom
}
// The projection coefficients describe the code vector
// corresponding to the updated atom.
B.segment(0,nsigUsing).noalias() = Zblk.transpose()*A;
Index ic2 = 0;
for(Index isig = 0; isig < nsig; ++isig){
if( X(iatm,isig) == 0.0f ) continue;
X(iatm,isig) = B(ic2);
++ic2;
}
D.col(iatm) = A;
}
}//single
}
}//parallel
delete[] dwork;
if( stopAtMaxError ) delete[] errsig;
}
//-----------------------------------------------------------