11use crate :: algorithms:: { Status , StopReason } ;
22use crate :: prelude:: algorithms:: Algorithms ;
33
4- pub use crate :: routines:: estimation:: ipm:: burke;
4+ pub use crate :: routines:: estimation:: ipm:: { burke, burke_ipm , burke_log } ;
55pub use crate :: routines:: estimation:: qr;
6+ use crate :: routines:: math:: logsumexp;
67use crate :: routines:: settings:: Settings ;
78
89use crate :: routines:: output:: { cycles:: CycleLog , cycles:: NPCycle , NPResult } ;
9- use crate :: structs:: psi:: { calculate_psi , Psi } ;
10+ use crate :: structs:: psi:: { calculate_psi_dispatch , Psi } ;
1011use crate :: structs:: theta:: Theta ;
1112use crate :: structs:: weights:: Weights ;
1213
@@ -160,8 +161,24 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
160161 if ( self . last_objf - self . objf ) . abs ( ) <= THETA_G && self . eps > THETA_E {
161162 self . eps /= 2. ;
162163 if self . eps <= THETA_E {
163- let pyl = psi * w. weights ( ) ;
164- self . f1 = pyl. iter ( ) . map ( |x| x. ln ( ) ) . sum ( ) ;
164+ // Compute f1 = sum(log(pyl)) where pyl = psi * w
165+ self . f1 = if self . psi . is_log_space ( ) {
166+ // For log-space: f1 = sum_i(logsumexp(log_psi[i,:] + log(w)))
167+ let log_w: Vec < f64 > = w. weights ( ) . iter ( ) . map ( |& x| x. ln ( ) ) . collect ( ) ;
168+ ( 0 ..psi. nrows ( ) )
169+ . map ( |i| {
170+ let combined: Vec < f64 > = ( 0 ..psi. ncols ( ) )
171+ . map ( |j| * psi. get ( i, j) + log_w[ j] )
172+ . collect ( ) ;
173+ logsumexp ( & combined)
174+ } )
175+ . sum ( )
176+ } else {
177+ // For regular space: f1 = sum(log(psi * w))
178+ let pyl = psi * w. weights ( ) ;
179+ pyl. iter ( ) . map ( |x| x. ln ( ) ) . sum ( )
180+ } ;
181+
165182 if ( self . f1 - self . f0 ) . abs ( ) <= THETA_F {
166183 tracing:: info!( "The model converged after {} cycles" , self . cycle, ) ;
167184 self . set_status ( Status :: Stop ( StopReason :: Converged ) ) ;
@@ -197,31 +214,29 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
197214 }
198215
199216 fn estimation ( & mut self ) -> Result < ( ) > {
200- self . psi = calculate_psi (
217+ let use_log_space = self . settings . advanced ( ) . log_space ;
218+
219+ self . psi = calculate_psi_dispatch (
201220 & self . equation ,
202221 & self . data ,
203222 & self . theta ,
204223 & self . error_models ,
205224 self . cycle == 1 && self . settings . config ( ) . progress ,
206225 self . cycle != 1 ,
226+ use_log_space,
207227 ) ?;
208228
209229 if let Err ( err) = self . validate_psi ( ) {
210230 bail ! ( err) ;
211231 }
212232
213- ( self . lambda , _) = match burke ( & self . psi ) {
214- Ok ( ( lambda, objf) ) => ( lambda, objf) ,
215- Err ( err) => {
216- bail ! ( "Error in IPM during estimation: {:?}" , err) ;
217- }
218- } ;
233+ ( self . lambda , _) = burke_ipm ( & self . psi )
234+ . map_err ( |err| anyhow:: anyhow!( "Error in IPM during estimation: {:?}" , err) ) ?;
219235 Ok ( ( ) )
220236 }
221237
222238 fn condensation ( & mut self ) -> Result < ( ) > {
223239 // Filter out the support points with lambda < max(lambda)/1000
224-
225240 let max_lambda = self
226241 . lambda
227242 . iter ( )
@@ -273,20 +288,16 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
273288 self . psi . filter_column_indices ( keep. as_slice ( ) ) ;
274289
275290 self . validate_psi ( ) ?;
276- ( self . lambda , self . objf ) = match burke ( & self . psi ) {
277- Ok ( ( lambda, objf) ) => ( lambda, objf) ,
278- Err ( err) => {
279- return Err ( anyhow:: anyhow!(
280- "Error in IPM during condensation: {:?}" ,
281- err
282- ) ) ;
283- }
284- } ;
291+
292+ ( self . lambda , self . objf ) = burke_ipm ( & self . psi )
293+ . map_err ( |err| anyhow:: anyhow!( "Error in IPM during condensation: {:?}" , err) ) ?;
285294 self . w = self . lambda . clone ( ) ;
286295 Ok ( ( ) )
287296 }
288297
289298 fn optimizations ( & mut self ) -> Result < ( ) > {
299+ let use_log_space = self . settings . advanced ( ) . log_space ;
300+
290301 self . error_models
291302 . clone ( )
292303 . iter_mut ( )
@@ -298,8 +309,6 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
298309 }
299310 } )
300311 . try_for_each ( |( outeq, em) | -> Result < ( ) > {
301- // OPTIMIZATION
302-
303312 let gamma_up = em. factor ( ) ? * ( 1.0 + self . gamma_delta [ outeq] ) ;
304313 let gamma_down = em. factor ( ) ? / ( 1.0 + self . gamma_delta [ outeq] ) ;
305314
@@ -309,35 +318,32 @@ impl<E: Equation + Send + 'static> Algorithms<E> for NPAG<E> {
309318 let mut error_model_down = self . error_models . clone ( ) ;
310319 error_model_down. set_factor ( outeq, gamma_down) ?;
311320
312- let psi_up = calculate_psi (
321+ let psi_up = calculate_psi_dispatch (
313322 & self . equation ,
314323 & self . data ,
315324 & self . theta ,
316325 & error_model_up,
317326 false ,
318327 true ,
328+ use_log_space,
319329 ) ?;
320- let psi_down = calculate_psi (
330+
331+ let psi_down = calculate_psi_dispatch (
321332 & self . equation ,
322333 & self . data ,
323334 & self . theta ,
324335 & error_model_down,
325336 false ,
326337 true ,
338+ use_log_space,
327339 ) ?;
328340
329- let ( lambda_up, objf_up) = match burke ( & psi_up) {
330- Ok ( ( lambda, objf) ) => ( lambda, objf) ,
331- Err ( err) => {
332- bail ! ( "Error in IPM during optim: {:?}" , err) ;
333- }
334- } ;
335- let ( lambda_down, objf_down) = match burke ( & psi_down) {
336- Ok ( ( lambda, objf) ) => ( lambda, objf) ,
337- Err ( err) => {
338- bail ! ( "Error in IPM during optim: {:?}" , err) ;
339- }
340- } ;
341+ let ( lambda_up, objf_up) = burke_ipm ( & psi_up)
342+ . map_err ( |err| anyhow:: anyhow!( "Error in IPM during optim: {:?}" , err) ) ?;
343+
344+ let ( lambda_down, objf_down) = burke_ipm ( & psi_down)
345+ . map_err ( |err| anyhow:: anyhow!( "Error in IPM during optim: {:?}" , err) ) ?;
346+
341347 if objf_up > self . objf {
342348 self . error_models . set_factor ( outeq, gamma_up) ?;
343349 self . objf = objf_up;
0 commit comments