11#pragma once
22
3+ #include " barretenberg/common/zip_view.hpp"
34#include " barretenberg/constants.hpp"
45#include " barretenberg/polynomials/univariate.hpp"
56#include < array>
@@ -10,15 +11,16 @@ namespace bb {
1011 * @brief Stores ZK masking values for witness polynomials and manages their folding across sumcheck rounds.
1112 *
1213 * @details When witness polynomials are allocated to trace_active_range (not full dyadic_size), the masking
13- * values at the last NUM_MASKED_ROWS positions are stored here as small "tail" polynomials. This struct:
14- * 1. Holds the tail polynomials (NUM_MASKED_ROWS coefficients at positions {n-3, n-2, n-1}, full virtual_size )
15- * 2. Tracks which entities are masked via an AllEntities<bool> flag
14+ * values at the tail positions are stored here as small polynomials. This struct:
15+ * 1. Holds tail polynomials (NUM_MASKED_ROWS coefficients; unshifted at {n-3,n-2,n-1}, shifted at {n-4,n-3,n-2} )
16+ * 2. Tracks which entities are masked via AllEntities<bool> (used by compute_disabled_contribution)
1617 * 3. Manages folded masking values across sumcheck rounds
1718 * 4. Computes claimed evaluation corrections via Lagrange products of challenges
18- * 5. Provides tail polynomials for PCS batching
19+ * 5. Stores tails for PCS commitment adjustment and Gemini batching
1920 *
2021 * Only used for flavors with UseRowDisablingPolynomial (not Translator, which uses a different ZK technique).
21- * Uses the AllEntities pattern: parallel structures indexed identically to ProverPolynomials.
22+ * Uses the AllEntities pattern: callers access tails by named field (e.g. tails.w_l) or via
23+ * get_masked()/get_shifted() for iteration — no pointer-matching lookups needed.
2224 */
2325template <typename Flavor> struct MaskingTailData {
2426 using FF = typename Flavor::FF;
@@ -34,7 +36,7 @@ template <typename Flavor> struct MaskingTailData {
3436 AllEntities<Polynomial> tails{};
3537
3638 // Folded masking values tracked across sumcheck rounds.
37- // [0]=even position value, [1]=odd position value.
39+ // [0]=even position value, [1]=odd position value. Default {0,0} for non-masked entities.
3840 AllEntities<std::array<FF, 2 >> folded{};
3941
4042 // Global folding state: 0 = not yet folded, 2 = after round 0, 1 = after round 1+.
@@ -53,33 +55,29 @@ template <typename Flavor> struct MaskingTailData {
5355 * / get_shifted() which are guaranteed parallel arrays.
5456 * Call once before any commits (e.g., at start of OinkProver::prove()).
5557 */
56- template < typename ProverPolynomials> void register_all_masked_polys ([[maybe_unused]] ProverPolynomials& polys )
58+ void register_all_masked_polys ()
5759 {
5860 size_t start = dyadic_size - NUM_MASKED_ROWS;
5961
6062 // 1. Mark masked entities and generate random tail values
61- auto masked_flags = is_masked.get_masked ();
62- auto masked_tails = tails.get_masked ();
63- for (size_t i = 0 ; i < masked_flags.size (); i++) {
64- masked_flags[i] = true ;
65- masked_tails[i] = Polynomial (NUM_MASKED_ROWS, dyadic_size, start);
63+ for (auto [flag, tail] : zip_view (is_masked.get_masked (), tails.get_masked ())) {
64+ flag = true ;
65+ tail = Polynomial (NUM_MASKED_ROWS, dyadic_size, start);
6666 for (size_t j = 0 ; j < NUM_MASKED_ROWS; j++) {
67- masked_tails[i] .at (start + j) = FF::random_element ();
67+ tail .at (start + j) = FF::random_element ();
6868 }
6969 }
7070 active = true ;
7171
7272 // 2. Derive shifted tails: get_to_be_shifted() and get_shifted() are parallel arrays.
7373 // All to-be-shifted sources are in get_masked(), so all shifted entries are active.
74- auto src_tails = tails.get_to_be_shifted ();
75- auto shifted_flags = is_masked.get_shifted ();
76- auto shifted_tails = tails.get_shifted ();
7774 size_t shift_start = start - 1 ;
78- for (size_t s = 0 ; s < shifted_tails.size (); s++) {
79- shifted_flags[s] = true ;
80- shifted_tails[s] = Polynomial (NUM_MASKED_ROWS, dyadic_size, shift_start);
75+ for (auto [src_tail, shifted_flag, shifted_tail] :
76+ zip_view (tails.get_to_be_shifted (), is_masked.get_shifted (), tails.get_shifted ())) {
77+ shifted_flag = true ;
78+ shifted_tail = Polynomial (NUM_MASKED_ROWS, dyadic_size, shift_start);
8179 for (size_t k = 0 ; k < NUM_MASKED_ROWS; k++) {
82- shifted_tails[s] .at (shift_start + k) = src_tails[s] .at (start + k);
80+ shifted_tail .at (shift_start + k) = src_tail .at (start + k);
8381 }
8482 }
8583 }
@@ -89,7 +87,8 @@ template <typename Flavor> struct MaskingTailData {
8987 * @param challenge The round challenge u_i.
9088 * @param round_idx The sumcheck round index (0-based).
9189 * @param round_size The round size BEFORE halving (2^{d-i}).
92- * @param pe Pointer to PE multivariates (needed for rounds 2+).
90+ * @param pe Pointer to PE multivariates (needed for rounds 2+: the even-position
91+ * value comes from the partially-evaluated table, not from folded masking state).
9392 */
9493 template <typename PolynomialCollection>
9594 void fold_masking_values (FF challenge,
@@ -105,49 +104,45 @@ template <typename Flavor> struct MaskingTailData {
105104 size_t start = dyadic_size - NUM_MASKED_ROWS;
106105
107106 // Unshifted masked: positions {n-3, n-2, n-1} have values {m0, m1, m2}, position n-4 = 0
108- auto masked_tails = tails.get_masked ();
109- auto masked_folded = folded.get_masked ();
110- for (size_t i = 0 ; i < masked_tails.size (); i++) {
111- FF m0 = masked_tails[i].at (start);
112- FF m1 = masked_tails[i].at (start + 1 );
113- FF m2 = masked_tails[i].at (start + 2 );
114- masked_folded[i][0 ] = challenge * m0;
115- masked_folded[i][1 ] = m1 + challenge * (m2 - m1);
107+ for (auto [tail, f] : zip_view (tails.get_masked (), folded.get_masked ())) {
108+ FF m0 = tail.at (start);
109+ FF m1 = tail.at (start + 1 );
110+ FF m2 = tail.at (start + 2 );
111+ f[0 ] = challenge * m0;
112+ f[1 ] = m1 + challenge * (m2 - m1);
116113 }
117114
118115 // Shifted: positions {n-4, n-3, n-2} have values {m0, m1, m2}, position n-1 = 0
119- auto shifted_tails = tails.get_shifted ();
120- auto shifted_folded = folded.get_shifted ();
121- for (size_t s = 0 ; s < shifted_tails.size (); s++) {
122- FF m0 = shifted_tails[s].at (start - 1 );
123- FF m1 = shifted_tails[s].at (start);
124- FF m2 = shifted_tails[s].at (start + 1 );
125- shifted_folded[s][0 ] = m0 + challenge * (m1 - m0);
126- shifted_folded[s][1 ] = m2 * (FF::one () - challenge);
116+ for (auto [tail, f] : zip_view (tails.get_shifted (), folded.get_shifted ())) {
117+ FF m0 = tail.at (start - 1 );
118+ FF m1 = tail.at (start);
119+ FF m2 = tail.at (start + 1 );
120+ f[0 ] = m0 + challenge * (m1 - m0);
121+ f[1 ] = m2 * (FF::one () - challenge);
127122 }
128123 folded_count = 2 ;
129124 } else if (round_idx == 1 ) {
130125 // Same formula for both unshifted and shifted: collapse two folded values into one
131- auto fold_round1 = [&](auto folded_refs) {
132- for (size_t i = 0 ; i < folded_refs. size (); i++ ) {
133- folded_refs[i][ 0 ] += challenge * (folded_refs[i][ 1 ] - folded_refs[i] [0 ]);
126+ auto fold = [&](auto folded_refs) {
127+ for (auto & f : folded_refs) {
128+ f[ 0 ] += challenge * (f[ 1 ] - f [0 ]);
134129 }
135130 };
136- fold_round1 (folded.get_masked ());
137- fold_round1 (folded.get_shifted ());
131+ fold (folded.get_masked ());
132+ fold (folded.get_shifted ());
138133 folded_count = 1 ;
139134 } else {
140135 BB_ASSERT (pe != nullptr );
141136 size_t even_pos = round_size - 2 ;
142137 // Interpolate between PE value and folded value
143- auto fold_round2_plus = [&](auto folded_refs, auto pe_refs) {
144- for (size_t i = 0 ; i < folded_refs. size (); i++ ) {
145- FF even_val = pe_refs[i] [even_pos];
146- folded_refs[i][ 0 ] = even_val + challenge * (folded_refs[i] [0 ] - even_val);
138+ auto fold = [&](auto folded_refs, auto pe_refs) {
139+ for (auto [f, p] : zip_view (folded_refs, pe_refs) ) {
140+ FF even_val = p [even_pos];
141+ f[ 0 ] = even_val + challenge * (f [0 ] - even_val);
147142 }
148143 };
149- fold_round2_plus (folded.get_masked (), pe->get_masked ());
150- fold_round2_plus (folded.get_shifted (), pe->get_shifted ());
144+ fold (folded.get_masked (), pe->get_masked ());
145+ fold (folded.get_shifted (), pe->get_shifted ());
151146 }
152147 }
153148
@@ -168,24 +163,20 @@ template <typename Flavor> struct MaskingTailData {
168163 size_t start = dyadic_size - NUM_MASKED_ROWS;
169164
170165 // Unshifted masked: Lagrange basis at positions {n-3, n-2, n-1}
171- auto masked_evals = evaluations.get_masked ();
172- auto masked_tails = tails.get_masked ();
173- for (size_t i = 0 ; i < masked_tails.size (); i++) {
174- FF m0 = masked_tails[i].at (start);
175- FF m1 = masked_tails[i].at (start + 1 );
176- FF m2 = masked_tails[i].at (start + 2 );
177- masked_evals[i] += common * (m0 * u0 * (FF::one () - u1) + m1 * (FF::one () - u0) * u1 + m2 * u0 * u1);
166+ for (auto [eval, tail] : zip_view (evaluations.get_masked (), tails.get_masked ())) {
167+ FF m0 = tail.at (start);
168+ FF m1 = tail.at (start + 1 );
169+ FF m2 = tail.at (start + 2 );
170+ eval += common * (m0 * u0 * (FF::one () - u1) + m1 * (FF::one () - u0) * u1 + m2 * u0 * u1);
178171 }
179172
180173 // Shifted: Lagrange basis at positions {n-4, n-3, n-2}
181- auto shifted_evals = evaluations.get_shifted ();
182- auto shifted_tails = tails.get_shifted ();
183- for (size_t s = 0 ; s < shifted_tails.size (); s++) {
184- FF m0 = shifted_tails[s].at (start - 1 );
185- FF m1 = shifted_tails[s].at (start);
186- FF m2 = shifted_tails[s].at (start + 1 );
187- shifted_evals[s] += common * (m0 * (FF::one () - u0) * (FF::one () - u1) + m1 * u0 * (FF::one () - u1) +
188- m2 * (FF::one () - u0) * u1);
174+ for (auto [eval, tail] : zip_view (evaluations.get_shifted (), tails.get_shifted ())) {
175+ FF m0 = tail.at (start - 1 );
176+ FF m1 = tail.at (start);
177+ FF m2 = tail.at (start + 1 );
178+ eval += common * (m0 * (FF::one () - u0) * (FF::one () - u1) + m1 * u0 * (FF::one () - u1) +
179+ m2 * (FF::one () - u0) * u1);
189180 }
190181 }
191182
@@ -202,20 +193,18 @@ template <typename Flavor> struct MaskingTailData {
202193 return ;
203194 }
204195
205- auto masked_polys = prover_polynomials.get_masked ();
206- auto masked_tails = tails.get_masked ();
207-
208- for (size_t i = 0 ; i < masked_polys.size (); i++) {
209- const auto & poly = masked_polys[i];
196+ // Pointer-matching against batcher lists is needed here since the batcher is an external
197+ // structure without flavor-aware getters.
198+ for (auto [poly, tail] : zip_view (prover_polynomials.get_masked (), tails.get_masked ())) {
210199 for (size_t u = 0 ; u < batcher.unshifted .size (); u++) {
211200 if (batcher.unshifted [u].data () == poly.data ()) {
212- batcher.add_unshifted_tail (u, Polynomial (masked_tails[i] ));
201+ batcher.add_unshifted_tail (u, Polynomial (tail ));
213202 break ;
214203 }
215204 }
216205 for (size_t s = 0 ; s < batcher.to_be_shifted_by_one .size (); s++) {
217206 if (batcher.to_be_shifted_by_one [s].data () == poly.data ()) {
218- batcher.add_shifted_tail (s, Polynomial (masked_tails[i] ));
207+ batcher.add_shifted_tail (s, Polynomial (tail ));
219208 break ;
220209 }
221210 }
0 commit comments