@@ -24,7 +24,7 @@ use crate::{
24
24
} ,
25
25
parizip,
26
26
poly:: multi:: { hypercube_eq_over_y, Mle } ,
27
- utils:: metrics_span_increment ,
27
+ utils:: metrics_span ,
28
28
} ;
29
29
30
30
/// A struct that holds the interaction-related data for a single GKR instance:
@@ -165,68 +165,75 @@ where
165
165
. unwrap ( ) ;
166
166
let gamma_pows = gamma. powers ( ) . take ( 2 * max_interactions) . collect_vec ( ) ;
167
167
168
- let trace_view_per_air_filtered = interactions_per_air
168
+ let ood_point_per_instance : Vec < _ > = interactions_per_air
169
169
. iter ( )
170
170
. zip ( trace_view_per_air. iter ( ) )
171
- . filter_map ( |( interactions, view) | {
172
- if interactions. is_empty ( ) {
173
- None
174
- } else {
175
- Some ( view)
176
- }
171
+ . filter_map ( |( interactions, view) | ( !interactions. is_empty ( ) ) . then_some ( view) )
172
+ . zip ( & gkr_artifact. n_variables_by_instance )
173
+ . map ( |( trace_view, & n_vars) | {
174
+ let height = trace_view. partitioned_main [ 0 ] . height ( ) ;
175
+ let log_height = log2_strict_usize ( height) ;
176
+ let instance_ood = & ood_point[ ood_point. len ( ) - n_vars..] ;
177
+ let ( _, r) = instance_ood. split_at ( n_vars - log_height) ;
178
+ r
177
179
} )
178
- . collect_vec ( ) ;
179
-
180
- let results: Vec < _ > = parizip ! (
181
- trace_view_per_air_filtered,
182
- gkr_instances,
183
- & gkr_artifact. n_variables_by_instance
184
- )
185
- . map ( |( trace_view, gkr_instance, & n_vars) | {
186
- let height = trace_view. partitioned_main [ 0 ] . height ( ) ;
187
- let log_height = log2_strict_usize ( height) ;
188
-
189
- let instance_ood = & ood_point[ ood_point. len ( ) - n_vars..] ;
190
- let ( _, r) = instance_ood. split_at ( n_vars - log_height) ;
191
- debug_assert_eq ! ( r. len( ) , log_height) ;
180
+ . collect ( ) ;
192
181
193
- let ( after_challenge_trace, exposed_values, numer_mle_claims, denom_mle_claims) =
194
- Self :: generate_aux ( gkr_instance, & gamma_pows, r) ;
182
+ let numer_mle_claims_per_instance: Vec < Vec < _ > > =
183
+ parizip ! ( & ood_point_per_instance, gkr_instances, )
184
+ . map ( |( r, gkr_instance) | {
185
+ gkr_instance
186
+ . numerators
187
+ . transpose ( )
188
+ . par_rows_mut ( )
189
+ . map ( |row| Mle :: eval_slice_with_buf ( row, r) )
190
+ . collect ( )
191
+ } )
192
+ . collect ( ) ;
193
+
194
+ let denom_mle_claims_per_instance: Vec < Vec < _ > > =
195
+ parizip ! ( & ood_point_per_instance, gkr_instances, )
196
+ . map ( |( r, gkr_instance) | {
197
+ gkr_instance
198
+ . denominators
199
+ . transpose ( )
200
+ . par_rows_mut ( )
201
+ . map ( |row| Mle :: eval_slice_with_buf ( row, r) )
202
+ . collect ( )
203
+ } )
204
+ . collect ( ) ;
205
+
206
+ for ( numer_mle, denom_mle) in numer_mle_claims_per_instance
207
+ . iter ( )
208
+ . zip ( denom_mle_claims_per_instance. iter ( ) )
209
+ {
210
+ for ( numer_mle_claim, denom_mle_claim) in numer_mle. iter ( ) . zip ( denom_mle. iter ( ) ) {
211
+ challenger. observe_ext_element ( * numer_mle_claim) ;
212
+ challenger. observe_ext_element ( * denom_mle_claim) ;
213
+ }
214
+ }
195
215
196
- (
197
- after_challenge_trace,
198
- exposed_values,
199
- numer_mle_claims,
200
- denom_mle_claims,
201
- )
202
- } )
203
- . collect ( ) ;
216
+ let results: Vec < _ > = metrics_span ( "generate_perm_trace_time_ms" , || {
217
+ parizip ! ( & ood_point_per_instance, gkr_instances)
218
+ . map ( |( r, gkr_instance) | Self :: generate_perm_trace ( gkr_instance, & gamma_pows, r) )
219
+ . collect ( )
220
+ } ) ;
204
221
205
222
let mut results_iter = results. into_iter ( ) ;
206
223
207
- let n_airs = interactions_per_air . len ( ) ;
208
- let mut after_challenge_trace_per_air = Vec :: with_capacity ( n_airs ) ;
209
- let mut exposed_values_per_air = Vec :: with_capacity ( n_airs ) ;
210
- let mut numer_mle_claims_per_instance = Vec :: with_capacity ( n_airs ) ;
211
- let mut denom_mle_claims_per_instance = Vec :: with_capacity ( n_airs ) ;
212
-
213
- for interactions in interactions_per_air . iter ( ) {
214
- if interactions . is_empty ( ) {
215
- after_challenge_trace_per_air . push ( None ) ;
216
- exposed_values_per_air . push ( None ) ;
217
- } else {
218
- let ( after_trace , exposed , numer_mle , denom_mle ) = results_iter . next ( ) . unwrap ( ) ;
224
+ let ( after_challenge_trace_per_air , exposed_values_per_air ) : ( Vec < _ > , Vec < _ > ) =
225
+ interactions_per_air
226
+ . iter ( )
227
+ . map ( |interactions| {
228
+ if interactions . is_empty ( ) {
229
+ ( None , None )
230
+ } else {
231
+ let ( after_trace , exposed ) = results_iter . next ( ) . unwrap ( ) ;
232
+ ( Some ( after_trace ) , Some ( exposed ) )
233
+ }
234
+ } )
235
+ . unzip ( ) ;
219
236
220
- for ( numer_mle_claim, denom_mle_claim) in numer_mle. iter ( ) . zip ( denom_mle. iter ( ) ) {
221
- challenger. observe_ext_element ( * numer_mle_claim) ;
222
- challenger. observe_ext_element ( * denom_mle_claim) ;
223
- }
224
- after_challenge_trace_per_air. push ( Some ( after_trace) ) ;
225
- exposed_values_per_air. push ( Some ( exposed) ) ;
226
- numer_mle_claims_per_instance. push ( numer_mle) ;
227
- denom_mle_claims_per_instance. push ( denom_mle) ;
228
- }
229
- }
230
237
GkrAuxData {
231
238
after_challenge_trace_per_air,
232
239
exposed_values_per_air,
@@ -235,59 +242,35 @@ where
235
242
}
236
243
}
237
244
238
- fn generate_aux (
245
+ fn generate_perm_trace (
239
246
gkr_instance : & GkrLogUpInstance < EF > ,
240
247
gamma_pows : & [ EF ] ,
241
248
r : & [ EF ] ,
242
- ) -> ( DenseMatrix < EF > , Vec < EF > , Vec < EF > , Vec < EF > ) {
243
- let numer_mle_claims : Vec < _ > = gkr_instance
249
+ ) -> ( DenseMatrix < EF > , Vec < EF > ) {
250
+ let s_at_rows : Vec < EF > = gkr_instance
244
251
. numerators
245
- . transpose ( )
246
- . par_rows_mut ( )
247
- . map ( |row| Mle :: eval_slice ( row, r) )
248
- . collect ( ) ;
249
-
250
- let denom_mle_claims: Vec < _ > = gkr_instance
251
- . denominators
252
- . transpose ( )
253
- . par_rows_mut ( )
254
- . map ( |row| Mle :: eval_slice ( row, r) )
252
+ . par_row_slices ( )
253
+ . zip ( gkr_instance. denominators . par_row_slices ( ) )
254
+ . map ( |( count_row, sigma_row) | {
255
+ itertools:: interleave ( count_row, sigma_row)
256
+ . zip ( gamma_pows)
257
+ . fold ( EF :: ZERO , |acc, ( & val, & gamma_pow) | acc + gamma_pow * val)
258
+ } )
255
259
. collect ( ) ;
256
260
257
- let ( after_challenge_trace, exposed_values) =
258
- metrics_span_increment ( "generate_perm_trace_time_ms" , || {
259
- let s_at_rows: Vec < EF > = gkr_instance
260
- . numerators
261
- . par_row_slices ( )
262
- . zip ( gkr_instance. denominators . par_row_slices ( ) )
263
- . map ( |( count_row, sigma_row) | {
264
- itertools:: interleave ( count_row, sigma_row)
265
- . zip ( gamma_pows)
266
- . fold ( EF :: ZERO , |acc, ( & val, & gamma_pow) | acc + gamma_pow * val)
267
- } )
268
- . collect ( ) ;
261
+ // TODO: Precompute these per height.
262
+ let eqs_at_r = hypercube_eq_over_y ( r) ;
269
263
270
- // TODO: Precompute these per height.
271
- let eqs_at_r = hypercube_eq_over_y ( r) ;
264
+ let mut partial_sum = EF :: ZERO ;
265
+ let mut after_challenge_trace_data = Vec :: with_capacity ( eqs_at_r. len ( ) * 3 ) ;
266
+ for ( eq_at_r, s_at_row) in eqs_at_r. iter ( ) . zip ( s_at_rows. iter ( ) ) {
267
+ after_challenge_trace_data. push ( * eq_at_r) ;
268
+ after_challenge_trace_data. push ( * s_at_row) ;
269
+ after_challenge_trace_data. push ( partial_sum) ;
272
270
273
- let mut partial_sum = EF :: ZERO ;
274
- let mut after_challenge_trace_data = Vec :: with_capacity ( eqs_at_r. len ( ) * 3 ) ;
275
- for ( eq_at_r, s_at_row) in eqs_at_r. iter ( ) . zip ( s_at_rows. iter ( ) ) {
276
- after_challenge_trace_data. push ( * eq_at_r) ;
277
- after_challenge_trace_data. push ( * s_at_row) ;
278
- after_challenge_trace_data. push ( partial_sum) ;
279
-
280
- partial_sum += * eq_at_r * * s_at_row;
281
- }
282
- let after_challenge_trace = RowMajorMatrix :: new ( after_challenge_trace_data, 3 ) ;
283
- ( after_challenge_trace, vec ! [ partial_sum] )
284
- } ) ;
285
-
286
- (
287
- after_challenge_trace,
288
- exposed_values,
289
- numer_mle_claims,
290
- denom_mle_claims,
291
- )
271
+ partial_sum += * eq_at_r * * s_at_row;
272
+ }
273
+ let after_challenge_trace = RowMajorMatrix :: new ( after_challenge_trace_data, 3 ) ;
274
+ ( after_challenge_trace, vec ! [ partial_sum] )
292
275
}
293
276
}
0 commit comments