@@ -12,13 +12,18 @@ use p3_util::log2_strict_usize;
1212
1313use crate :: {
1414 air_builders:: symbolic:: symbolic_expression:: { SymbolicEvaluator , SymbolicExpression } ,
15- gkr:: { GkrArtifact , Layer , Layer :: LogUpGeneric } ,
15+ gkr:: {
16+ GkrArtifact ,
17+ Layer :: { self , LogUpGeneric } ,
18+ } ,
1619 interaction:: {
1720 gkr_log_up:: { num_interaction_dimensions, GkrAuxData , GkrLogUpPhase } ,
1821 trace:: Evaluator ,
1922 PairTraceView , SymbolicInteraction ,
2023 } ,
24+ parizip,
2125 poly:: multi:: { hypercube_eq_over_y, Mle } ,
26+ utils:: metrics_span_increment,
2227} ;
2328
2429/// A struct that holds the interaction-related data for a single GKR instance:
@@ -148,7 +153,7 @@ where
148153 } else {
149154 Some ( GkrLogUpInstance :: from_interactions (
150155 trace_view,
151- & interactions,
156+ interactions,
152157 beta_pows,
153158 ) )
154159 }
@@ -181,29 +186,30 @@ where
181186 } )
182187 . collect_vec ( ) ;
183188
184- let results: Vec < _ > = trace_view_per_air_filtered
185- . par_iter ( )
186- . zip ( gkr_instances. par_iter ( ) )
187- . zip ( gkr_artifact. n_variables_by_instance . par_iter ( ) )
188- . map ( |( ( trace_view, gkr_instance) , & n_vars) | {
189- let height = trace_view. partitioned_main [ 0 ] . height ( ) ;
190- let log_height = log2_strict_usize ( height) ;
191-
192- let instance_ood = & ood_point[ ood_point. len ( ) - n_vars..] ;
193- let ( _, r) = instance_ood. split_at ( n_vars - log_height) ;
194- debug_assert_eq ! ( r. len( ) , log_height) ;
195-
196- let ( after_challenge_trace, exposed_values, count_mle_claims, sigma_mle_claims) =
197- Self :: generate_aux ( gkr_instance, & gamma_pows, r) ;
198-
199- (
200- after_challenge_trace,
201- exposed_values,
202- count_mle_claims,
203- sigma_mle_claims,
204- )
205- } )
206- . collect ( ) ;
189+ let results: Vec < _ > = parizip ! (
190+ trace_view_per_air_filtered,
191+ gkr_instances,
192+ & gkr_artifact. n_variables_by_instance
193+ )
194+ . map ( |( trace_view, gkr_instance, & n_vars) | {
195+ let height = trace_view. partitioned_main [ 0 ] . height ( ) ;
196+ let log_height = log2_strict_usize ( height) ;
197+
198+ let instance_ood = & ood_point[ ood_point. len ( ) - n_vars..] ;
199+ let ( _, r) = instance_ood. split_at ( n_vars - log_height) ;
200+ debug_assert_eq ! ( r. len( ) , log_height) ;
201+
202+ let ( after_challenge_trace, exposed_values, count_mle_claims, sigma_mle_claims) =
203+ Self :: generate_aux ( gkr_instance, & gamma_pows, r) ;
204+
205+ (
206+ after_challenge_trace,
207+ exposed_values,
208+ count_mle_claims,
209+ sigma_mle_claims,
210+ )
211+ } )
212+ . collect ( ) ;
207213
208214 let mut results_iter = results. into_iter ( ) ;
209215
@@ -257,32 +263,38 @@ where
257263 . map ( |row| Mle :: eval_slice ( row, r) )
258264 . collect ( ) ;
259265
260- let s_at_rows: Vec < EF > = gkr_instance
261- . counts
262- . par_row_slices ( )
263- . zip ( gkr_instance. sigmas . par_row_slices ( ) )
264- . map ( |( count_row, sigma_row) | {
265- itertools:: interleave ( count_row, sigma_row)
266- . zip ( gamma_pows)
267- . fold ( EF :: ZERO , |acc, ( & val, & gamma_pow) | acc + gamma_pow * val)
268- } )
269- . collect ( ) ;
270-
271- let eqs_at_r = hypercube_eq_over_y ( r) ;
272-
273- let mut partial_sum = EF :: ZERO ;
274- let mut after_challenge_trace_data = Vec :: with_capacity ( eqs_at_r. len ( ) * 3 ) ;
275- for ( eq_at_r, s_at_row) in eqs_at_r. iter ( ) . zip ( s_at_rows. iter ( ) ) {
276- after_challenge_trace_data. push ( * eq_at_r) ;
277- after_challenge_trace_data. push ( * s_at_row) ;
278- after_challenge_trace_data. push ( partial_sum) ;
266+ let ( after_challenge_trace, exposed_values) =
267+ metrics_span_increment ( "generate_perm_trace_time_ms" , || {
268+ let s_at_rows: Vec < EF > = gkr_instance
269+ . counts
270+ . par_row_slices ( )
271+ . zip ( gkr_instance. sigmas . par_row_slices ( ) )
272+ . map ( |( count_row, sigma_row) | {
273+ itertools:: interleave ( count_row, sigma_row)
274+ . zip ( gamma_pows)
275+ . fold ( EF :: ZERO , |acc, ( & val, & gamma_pow) | acc + gamma_pow * val)
276+ } )
277+ . collect ( ) ;
278+
279+ // TODO: Precompute these per height.
280+ let eqs_at_r = hypercube_eq_over_y ( r) ;
281+
282+ let mut partial_sum = EF :: ZERO ;
283+ let mut after_challenge_trace_data = Vec :: with_capacity ( eqs_at_r. len ( ) * 3 ) ;
284+ for ( eq_at_r, s_at_row) in eqs_at_r. iter ( ) . zip ( s_at_rows. iter ( ) ) {
285+ after_challenge_trace_data. push ( * eq_at_r) ;
286+ after_challenge_trace_data. push ( * s_at_row) ;
287+ after_challenge_trace_data. push ( partial_sum) ;
288+
289+ partial_sum += * eq_at_r * * s_at_row;
290+ }
291+ let after_challenge_trace = RowMajorMatrix :: new ( after_challenge_trace_data, 3 ) ;
292+ ( after_challenge_trace, vec ! [ partial_sum] )
293+ } ) ;
279294
280- partial_sum += * eq_at_r * * s_at_row;
281- }
282- let after_challenge_trace = RowMajorMatrix :: new ( after_challenge_trace_data, 3 ) ;
283295 (
284296 after_challenge_trace,
285- vec ! [ partial_sum ] ,
297+ exposed_values ,
286298 count_mle_claims,
287299 sigma_mle_claims,
288300 )
0 commit comments