Feat/streaming_prover #1114
base: main
Conversation
- Merging working spartan outer into the new api
- Merging working spartan outer into the new api. - things don't compile yet
R1CS eval changes brought in
Tests seem to pass. Double check with gitools if things are merged.
Tests seem to pass. Think I have everything.
Tests pass (at least the ones that were passing)
Typos on expanding table
Optimising stream to linear started
The linear schedule is fine, but the streaming schedule re-computation is highly suboptimal.
A much faster re-computation of Az/Bz with parallel code
Stream to linear is pretty much as fast as it needs to be.
jolt-core/src/zkvm/spartan/outer.rs
Outdated
```rust
let grid_az_ptr = grid_az.as_mut_ptr() as usize;
let grid_bz_ptr = grid_bz.as_mut_ptr() as usize;
let chunk_size = 4096;
let num_chunks = (jlen + chunk_size - 1) / chunk_size;
(0..num_chunks).into_par_iter().for_each(move |chunk_idx| {
    let start = chunk_idx * chunk_size;
    let end = (start + chunk_size).min(jlen);

    let az_ptr = grid_az_ptr as *mut F;
    let bz_ptr = grid_bz_ptr as *mut F;

    for j in start..end {
        let az_j = acc_az[j].barrett_reduce();
        let bz_first_j = acc_bz_first[j].barrett_reduce();
        let bz_second_j = acc_bz_second[j].barrett_reduce();

        unsafe {
            *az_ptr.add(j) = az_j;
            *bz_ptr.add(j) = bz_first_j + bz_second_j;
        }
    }
});
```
There appears to be a potential race condition in this parallel reduction. Each thread is using the same raw pointers (`az_ptr` and `bz_ptr`) and indexing with the absolute `j` value rather than a chunk-relative index.
To fix this, either:
- Use slice-based access instead of raw pointers: `grid_az[j] = az_j;`
- Or, if using pointers for performance, adjust the index to be relative to the start of each chunk: `*az_ptr.add(j) = az_j;` should be `*az_ptr.add(start + (j - start)) = az_j;`, or more simply `*az_ptr.add(j) = az_j;`.
The current approach could lead to threads writing to overlapping memory locations if the chunk boundaries aren't calculated correctly.
Spotted by Graphite Agent
A functional version of this has been implemented, but I'll move the design around. This is at least correct.
quangvdao
left a comment
Thanks! There are a number of changes I'd like to land before this is merge-able:
- Can you clean up all the TODOs, fix the performance regression, and streamline the code?
- When switching to linear time, ideally we should compute the evals of that round as we materialize, rather than doing another pass over the bound evals
- More generally, I think the following flow might be preferable for computing the prover message (I think I've talked about this a number of times):
```rust
fn compute_message(...) {
    // 1. Build the multiquadratic eval for this window, if it does not exist (or the number of variables has fallen to 0)
    //    1.a) If we are streaming from trace, and we don't want to materialize for this window (i.e. not the last window), call a `stream_eval` kernel
    //    1.b) If we are streaming from trace, and we want to materialize for this window (i.e. last window of streaming), call a `stream_materialize_eval` kernel
    //    1.c) If we have materialized, call a `materialize_eval` kernel
    // 2. From the multiquadratic poly, derive the current round's uni poly.
}
```
This allows us to be flexible on the number of rounds to compute, even when we have the materialized values. It is rare that we will use it on CPU, but we will most certainly use it on GPU, and it would be good to have a reference CPU version. It will also collapse into the linear-time version if you compute smartly (the multiquadratic evals should omit / not compute the eval at (1, 1, ..., 1)).
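For reference, a minimal CPU-side sketch of what that dispatch could look like; every name here (`EvalSource`, the three `*_eval` stubs, `to_round_poly`) is a placeholder rather than the actual Jolt API, and the stub bodies are dummies.
```rust
// Placeholder dispatch for the compute_message flow described above.
enum EvalSource {
    Streaming,          // 1.a: stream from trace, don't materialize (not the last window)
    StreamingLastWindow, // 1.b: stream from trace and materialize (last streaming window)
    Materialized,       // 1.c: bound evals are already materialized in memory
}

fn stream_eval() -> Vec<u64> { vec![0; 8] }             // stub
fn stream_materialize_eval() -> Vec<u64> { vec![0; 8] } // stub
fn materialize_eval() -> Vec<u64> { vec![0; 8] }        // stub
fn to_round_poly(evals: &[u64]) -> Vec<u64> { evals.to_vec() } // stub

fn compute_message(source: EvalSource, multiquadratic: &mut Option<Vec<u64>>) -> Vec<u64> {
    // 1. Build the multiquadratic evals for this window if they don't exist yet.
    if multiquadratic.is_none() {
        let evals = match source {
            EvalSource::Streaming => stream_eval(),
            EvalSource::StreamingLastWindow => stream_materialize_eval(),
            EvalSource::Materialized => materialize_eval(),
        };
        *multiquadratic = Some(evals);
    }
    // 2. Derive the current round's univariate poly from the multiquadratic evals.
    to_round_poly(multiquadratic.as_ref().unwrap())
}
```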
NEXT: Proper parallelisation. Fused materialisation (this is currently OK for grids of window size 1), but not generalised.
…nows. Well that's next.
For now we've adapted the API to suit the above.
moodlezoup
left a comment
Which of the new OuterRemainingSumcheckProver methods are going to be used in practice? I see that several method names are currently prefixed with an underscore, and some appear to be serial algorithms with a parallel counterpart. Let's delete anything that we're not going to use (including some of the stuff that's currently commented out).
```rust
let eval_at_inf = self.evals[old_base_idx + 2]; // z_0 = ∞

self.evals[new_idx] =
    eval_at_0 * (one - r) + eval_at_1 * r + eval_at_inf * r * (r - one);
```
nit: can pull r * (r - one) out of the for loop
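For illustration, a hedged fragment of what the hoisted version might look like (loop bounds and index bookkeeping from the surrounding code are elided):
```rust
// Hoist the loop-invariant factors once, before the for loop.
let one_minus_r = one - r;
let r_times_r_minus_one = r * (r - one);
// ... then, inside the loop body:
self.evals[new_idx] =
    eval_at_0 * one_minus_r + eval_at_1 * r + eval_at_inf * r_times_r_minus_one;
```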
I'll confirm the ones that'll be used after benching them. They'll likely be some specialised ones for small window sizes for speed.
I recommend keeping some of the serial implementations -- as the parallel code can often be hard to follow -- and there's now a fair amount of nesting.
```rust
let lazy_trace_iter_ = lazy_trace_iter.clone();
// NOTE: this will materialise the trace in full
let trace: Vec<Cycle> = lazy_trace_iter.by_ref().collect();
//let trace = lazy_trace_iter.by_ref();
```
nit: delete
```rust
// Clone and advance to target cycle
let mut iter = checkpoints[checkpoint_idx].clone();
for _ in 0..offset {
    iter.next();
}
```
We should only clone each checkpoint once for the checkpoint_interval cycles it's responsible for.
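Something along these lines might work; `checkpoints`, `checkpoint_interval`, and `process_cycle` are assumed names standing in for the surrounding code.
```rust
// Clone each checkpoint iterator once and walk its whole interval, instead of
// re-cloning and skipping `offset` cycles for every target cycle.
checkpoints.iter().enumerate().for_each(|(checkpoint_idx, checkpoint)| {
    let mut iter = checkpoint.clone(); // one clone per checkpoint
    for offset in 0..checkpoint_interval {
        let Some(cycle) = iter.next() else { break };
        let global_idx = checkpoint_idx * checkpoint_interval + offset;
        process_cycle(global_idx, cycle); // placeholder for the per-cycle work
    }
});
```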
```rust
// Streaming windows are not defined for HighToLow in the current
// Spartan code paths; return neutral head tables.
(&self.one_table, &self.one_table)
```
Seems like this should be `unimplemented!` for now?
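i.e., something like the following in place of returning the neutral tables (the message string is just illustrative):
```rust
// Fail loudly if a HighToLow caller ever reaches this path, rather than
// silently returning neutral head tables.
unimplemented!("streaming head tables are not defined for HighToLow binding order")
```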
```rust
}
BindingOrder::HighToLow => {
    // Not used for the outer Spartan streaming code.
    vec![F::one()]
```
same here, prefer `unimplemented!`
```rust
/// Degree bound of the sumcheck round polynomials for [`OuterRemainingSumcheckVerifier`].
const OUTER_REMAINING_DEGREE_BOUND: usize = 3;
const INFINITY: usize = 2; // 2 represents ∞ in base-3
```
This is kind of confusing to me, doesn't 2 represent 2 in base-3? 😅
You're right. This is a remnant from previous naming -- the actual value does not matter. I'll change this, thanks
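A possible rename, for illustration only (`INFINITY_SLOT` is a hypothetical name; per the binding snippet above, 2 is the slot index of the ∞ evaluation within each per-variable (0, 1, ∞) triple):
```rust
// Hypothetical rename: 2 is the index of the ∞ evaluation within each
// (eval at 0, eval at 1, eval at ∞) triple, not a base-3 digit meaning ∞.
const INFINITY_SLOT: usize = 2;
```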
```rust
first_round_evals: (F, F),
#[allocative(skip)]
params: OuterRemainingSumcheckParams<F>,
lagrange_evals_r0: [F; 10],
```
where does the number 10 come from? would be good to use a named const here
It should be OUTER_SKIP_DOMAIN_SIZE
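i.e., roughly the following, assuming OUTER_SKIP_DOMAIN_SIZE is (or will be) a named constant equal to 10 in this module:
```rust
const OUTER_SKIP_DOMAIN_SIZE: usize = 10;
// ... and in the struct field:
lagrange_evals_r0: [F; OUTER_SKIP_DOMAIN_SIZE],
```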
jolt-core/src/zkvm/spartan/outer.rs
Outdated
```rust
let grid_az_ptr = grid_az.as_mut_ptr() as usize;
let grid_bz_ptr = grid_bz.as_mut_ptr() as usize;
let chunk_size = 4096;
```
should this be `min(4096, jlen)`?
jolt-core/src/zkvm/spartan/outer.rs
Outdated
```rust
unsafe {
    *az_ptr.add(j) = az_j;
    *bz_ptr.add(j) = bz_first_j + bz_second_j;
}
```
instead of this unsafe pointer stuff, can't we just use `grid_az.par_chunks_mut(chunk_size).zip(grid_bz.par_chunks_mut(chunk_size)).enumerate()` or something?
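For reference, a safe slice-based sketch along those lines; it assumes rayon's prelude is in scope and that `acc_az`, `acc_bz_first`, and `acc_bz_second` remain indexable by the absolute position `j`, as in the original.
```rust
// Each parallel task gets disjoint mutable chunks of grid_az / grid_bz,
// so no raw pointers or unsafe writes are needed.
grid_az
    .par_chunks_mut(chunk_size)
    .zip(grid_bz.par_chunks_mut(chunk_size))
    .enumerate()
    .for_each(|(chunk_idx, (az_chunk, bz_chunk))| {
        let start = chunk_idx * chunk_size;
        for (offset, (az_out, bz_out)) in
            az_chunk.iter_mut().zip(bz_chunk.iter_mut()).enumerate()
        {
            let j = start + offset;
            *az_out = acc_az[j].barrett_reduce();
            *bz_out = acc_bz_first[j].barrett_reduce() + acc_bz_second[j].barrett_reduce();
        }
    });
```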
```rust
let mut acc_bz_first = vec![Acc6S::<F>::zero(); jlen];
let mut acc_bz_second = vec![Acc7S::<F>::zero(); jlen];

if !parallel {
```
Are we ever going to use the serial version in practice?
A generalised sum-check API.
NOTE: Parallel performance of compute message is sub-optimal right now, due to the increased complexity of building a general-sized multi-quadratic for window sizes >= 1.