diff --git a/Cargo.lock b/Cargo.lock index b2eba1c80d..1641c8b186 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4915,6 +4915,7 @@ dependencies = [ "test-log", "thiserror 1.0.69", "tracing", + "zerocopy 0.8.25", ] [[package]] @@ -5465,6 +5466,7 @@ dependencies = [ "serde-big-array", "strum", "test-case", + "zerocopy 0.8.25", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3f79027b71..c881ead3e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -221,6 +221,7 @@ rrs-lib = "0.1.0" rand = { version = "0.8.5", default-features = false } hex = { version = "0.4.3", default-features = false } serde-big-array = "0.5.1" +zerocopy = "0.8.25" # default-features = false for no_std for use in guest programs itertools = { version = "0.14.0", default-features = false } diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 80e6794b48..391bdc9e1f 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -35,6 +35,7 @@ eyre.workspace = true derivative.workspace = true static_assertions.workspace = true getset.workspace = true +zerocopy = { workspace = true, features = ["derive"] } [dev-dependencies] test-log.workspace = true diff --git a/crates/vm/src/arch/integration_api.rs b/crates/vm/src/arch/integration_api.rs index d53ed075fe..156630d560 100644 --- a/crates/vm/src/arch/integration_api.rs +++ b/crates/vm/src/arch/integration_api.rs @@ -1,4 +1,4 @@ -use std::{array::from_fn, borrow::Borrow, marker::PhantomData, sync::Arc}; +use std::{any::type_name, array::from_fn, borrow::Borrow, marker::PhantomData, sync::Arc}; use openvm_circuit_primitives::utils::next_power_of_two_or_zero; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -6,7 +6,7 @@ use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_field::{Field, FieldAlgebra, PackedValue, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, prover::types::AirProofInput, @@ -14,6 +14,7 @@ use openvm_stark_backend::{ AirRef, Chip, ChipUsageGetter, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use zerocopy::FromBytes; use super::{ExecutionState, InsExecutorE1, InstructionExecutor, Result, VmStateMut}; use crate::system::memory::{ @@ -54,48 +55,6 @@ pub trait VmAdapterAir: BaseAir { fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var; } -// TODO: delete -/// Trait to be implemented on primitive chip to integrate with the machine. -pub trait VmCoreChip> { - /// Minimum data that must be recorded to be able to generate trace for one row of - /// `PrimitiveAir`. - type Record: Send + Serialize + DeserializeOwned; - /// The primitive AIR with main constraints that do not depend on memory and other - /// architecture-specifics. - type Air: BaseAirWithPublicValues + Clone; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)>; - - fn get_opcode_name(&self, opcode: usize) -> String; - - /// Populates `row_slice` with values corresponding to `record`. - /// The provided `row_slice` will have length equal to `self.air().width()`. - /// This function will be called for each row in the trace which is being used, and all other - /// rows in the trace will be filled with zeroes. - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record); - - /// Returns a list of public values to publish. - fn generate_public_values(&self) -> Vec { - vec![] - } - - fn air(&self) -> &Self::Air; - - /// Finalize the trace, especially the padded rows if the all-zero rows don't satisfy the - /// constraints. This is done **after** records are consumed and the trace matrix is - /// generated. Most implementations should just leave the default implementation if padding - /// with rows of all 0s satisfies the constraints. - fn finalize(&self, _trace: &mut RowMajorMatrix, _num_records: usize) { - // do nothing by default - } -} - pub trait VmCoreAir: BaseAirWithPublicValues where AB: AirBuilder, @@ -127,23 +86,6 @@ where } } -// TODO: delete -pub struct AdapterRuntimeContext> { - /// Leave as `None` to allow the adapter to decide the `to_pc` automatically. - pub to_pc: Option, - pub writes: I::Writes, -} - -impl> AdapterRuntimeContext { - /// Leave `to_pc` as `None` to allow the adapter to decide the `to_pc` automatically. - pub fn without_pc(writes: impl Into) -> Self { - Self { - to_pc: None, - writes: writes.into(), - } - } -} - pub struct AdapterAirContext> { /// Leave as `None` to allow the adapter to decide the `to_pc` automatically. pub to_pc: Option, @@ -152,42 +94,76 @@ pub struct AdapterAirContext> { pub instruction: I::ProcessedInstruction, } +/// Given some minimum metadata of type `Layout` that specifies the record size, the `RecordArena` +/// should allocate a buffer, of size possibly larger than the record, and then return mutable +/// pointers to the record within the buffer. +pub trait RecordArena<'a, Layout, RecordMut> { + /// Allocates underlying buffer and returns a mutable reference `RecordMut`. + /// Note that calling this function may not call an underlying memory allocation as the record + /// arena may be virtual. + fn alloc(&'a mut self, layout: Layout) -> RecordMut; +} + +/// ZST to represent empty layout. Used when the layout can be inferred from other context (such as +/// AIR or record types). +pub struct EmptyLayout; + /// Interface for trace generation of a single instruction.The trace is provided as a mutable /// buffer during both instruction execution and trace generation. /// It is expected that no additional memory allocation is necessary and the trace buffer /// is sufficient, with possible overwriting. pub trait TraceStep { - fn execute( + type RecordLayout; + type RecordMut<'a>; + + fn execute<'buf, RA>( &mut self, state: VmStateMut, CTX>, instruction: &Instruction, - // TODO(ayush): combine to a single struct - trace: &mut [F], - trace_offset: &mut usize, - // TODO(ayush): move air inside step and remove width - width: usize, - ) -> Result<()>; + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>; + + /// Returns a list of public values to publish. + fn generate_public_values(&self) -> Vec { + vec![] + } + + /// Displayable opcode name for logging and debugging purposes. + fn get_opcode_name(&self, opcode: usize) -> String; +} + +// TODO[jpw]: this might be temporary trait before moving trace to CTX +pub trait RowMajorMatrixArena { + fn with_capacity(height: usize, width: usize) -> Self; + fn width(&self) -> usize; + fn trace_offset(&self) -> usize; + fn into_matrix(self) -> RowMajorMatrix; +} +// TODO[jpw]: revisit if this trait makes sense +pub trait TraceFiller { /// Populates `trace`. This function will always be called after - /// [`TraceStep::execute`], so the `trace` should already contain context necessary to - /// fill in the rest of it. + /// [`TraceStep::execute`], so the `trace` should already contain the records necessary to fill + /// in the rest of it. // TODO(ayush): come up with a better abstraction for chips that fill a dynamic number of rows fn fill_trace( &self, mem_helper: &MemoryAuxColsFactory, - trace: &mut [F], - width: usize, + trace: &mut RowMajorMatrix, rows_used: usize, ) where Self: Send + Sync, - F: Send + Sync, + F: Send + Sync + Clone, { - trace[..rows_used * width] + let width = trace.width(); + trace.values[..rows_used * width] .par_chunks_exact_mut(width) .for_each(|row_slice| { self.fill_trace_row(mem_helper, row_slice); }); - trace[rows_used * width..] + trace.values[rows_used * width..] .par_chunks_exact_mut(width) .for_each(|row_slice| { self.fill_dummy_trace_row(mem_helper, row_slice); @@ -196,11 +172,11 @@ pub trait TraceStep { /// Populates `row_slice`. This function will always be called after /// [`TraceStep::execute`], so the `row_slice` should already contain context necessary to - /// fill in the rest of the row. This function will be called for each row in the trace which is - /// being used, and all other rows in the trace will be filled with zeroes. + /// fill in the rest of the row. This function will be called for each row in the trace which + /// is being used, and for all other rows in the trace see `fill_dummy_trace_row`. /// /// The provided `row_slice` will have length equal to the width of the AIR. - fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + fn fill_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, _row_slice: &mut [F]) { unreachable!("fill_trace_row is not implemented") } @@ -208,54 +184,95 @@ pub trait TraceStep { /// By default the trace is padded with empty (all 0) rows to make the height a power of 2. /// /// The provided `row_slice` will have length equal to the width of the AIR. - fn fill_dummy_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + fn fill_dummy_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, _row_slice: &mut [F]) { // By default, the row is filled with zeroes } - /// Returns a list of public values to publish. - fn generate_public_values(&self) -> Vec { - vec![] +} + +// TEMP[jpw]: buffer should be inside CTX +pub struct MatrixRecordArena { + pub trace_buffer: Vec, + // TODO(ayush): width should be a constant? + pub width: usize, + pub trace_offset: usize, +} + +impl MatrixRecordArena { + pub fn alloc_single_row(&mut self) -> &mut [u8] { + let start = self.trace_offset; + self.trace_offset += self.width; + let row_slice = &mut self.trace_buffer[start..self.trace_offset]; + let size = size_of_val(row_slice); + let ptr = row_slice as *mut [F] as *mut u8; + // SAFETY: + // - `ptr` is non-null + // - `size` is correct + // - alignment of `u8` is always satisfied + unsafe { &mut *std::ptr::slice_from_raw_parts_mut(ptr, size) } } +} - /// Displayable opcode name for logging and debugging purposes. - fn get_opcode_name(&self, opcode: usize) -> String; +impl RowMajorMatrixArena for MatrixRecordArena { + fn with_capacity(height: usize, width: usize) -> Self { + let trace_buffer = F::zero_vec(height * width); + Self { + trace_buffer, + width, + trace_offset: 0, + } + } + + fn width(&self) -> usize { + self.width + } + + fn trace_offset(&self) -> usize { + self.trace_offset + } + + fn into_matrix(self) -> RowMajorMatrix { + RowMajorMatrix::new(self.trace_buffer, self.width) + } } // TODO(ayush): rename to ChipWithExecutionContext or something -pub struct NewVmChipWrapper { +pub struct NewVmChipWrapper { pub air: AIR, pub step: STEP, - pub trace_buffer: Vec, - // TODO(ayush): width should be a constant? - width: usize, - buffer_idx: usize, + pub arena: RA, mem_helper: SharedMemoryHelper, } -impl NewVmChipWrapper +impl NewVmChipWrapper where F: Field, AIR: BaseAir, + RA: RowMajorMatrixArena, { pub fn new(air: AIR, step: STEP, height: usize, mem_helper: SharedMemoryHelper) -> Self { - assert!(height == 0 || height.is_power_of_two()); let width = air.width(); - let trace_buffer = F::zero_vec(height * width); + assert!(height == 0 || height.is_power_of_two()); + assert!( + align_of::() >= align_of::(), + "type {} should have at least alignment of u32", + type_name::() + ); + let arena = RA::with_capacity(height, width); Self { air, step, - trace_buffer, - width, - buffer_idx: 0, + arena, mem_helper, } } } -impl InstructionExecutor for NewVmChipWrapper +impl InstructionExecutor for NewVmChipWrapper where F: PrimeField32, STEP: TraceStep // TODO: CTX? + StepExecutorE1, + for<'buf> RA: RecordArena<'buf, STEP::RecordLayout, STEP::RecordMut<'buf>>, { fn execute( &mut self, @@ -269,13 +286,7 @@ where memory: &mut memory.memory, ctx: &mut (), }; - self.step.execute( - state, - instruction, - &mut self.trace_buffer, - &mut self.buffer_idx, - self.width, - )?; + self.step.execute(state, instruction, &mut self.arena)?; Ok(ExecutionState { pc, @@ -292,92 +303,98 @@ where // - `Air` is an `Air` for all `AB: AirBuilder`s needed by stark-backend // which is equivalent to saying it implements AirRef // The where clauses to achieve this statement is unfortunately really verbose. -impl Chip for NewVmChipWrapper, AIR, STEP> +impl Chip for NewVmChipWrapper, AIR, STEP, RA> where SC: StarkGenericConfig, Val: PrimeField32, - STEP: TraceStep, ()> + Send + Sync, + STEP: TraceStep, ()> + TraceFiller, ()> + Send + Sync, AIR: Clone + AnyRap + 'static, + RA: RowMajorMatrixArena>, { fn air(&self) -> AirRef { Arc::new(self.air.clone()) } - fn generate_air_proof_input(mut self) -> AirProofInput { - assert_eq!(self.buffer_idx % self.width, 0); - let rows_used = self.current_trace_height(); + fn generate_air_proof_input(self) -> AirProofInput { + let width = self.arena.width(); + assert_eq!(self.arena.trace_offset() % width, 0); + let rows_used = self.arena.trace_offset() / width; let height = next_power_of_two_or_zero(rows_used); + let mut trace = self.arena.into_matrix(); // This should be automatic since trace_buffer's height is a power of two: - assert!(height.checked_mul(self.width).unwrap() <= self.trace_buffer.len()); - self.trace_buffer.truncate(height * self.width); + assert!(height.checked_mul(width).unwrap() <= trace.values.len()); + trace.values.truncate(height * width); let mem_helper = self.mem_helper.as_borrowed(); - self.step - .fill_trace(&mem_helper, &mut self.trace_buffer, self.width, rows_used); + self.step.fill_trace(&mem_helper, &mut trace, rows_used); drop(self.mem_helper); - let trace = RowMajorMatrix::new(self.trace_buffer, self.width); - // self.inner.finalize(&mut trace, num_records); AirProofInput::simple(trace, self.step.generate_public_values()) } } -impl ChipUsageGetter for NewVmChipWrapper +impl ChipUsageGetter for NewVmChipWrapper where C: Sync, + RA: RowMajorMatrixArena, { fn air_name(&self) -> String { get_air_name(&self.air) } fn current_trace_height(&self) -> usize { - self.buffer_idx / self.width + self.arena.trace_offset() / self.arena.width() } fn trace_width(&self) -> usize { - self.width + self.arena.width() } } -// TODO[jpw]: switch read,write to store into abstract buffer, then fill_trace_row using buffer /// A helper trait for expressing generic state accesses within the implementation of /// [TraceStep]. Note that this is only a helper trait when the same interface of state access /// is reused or shared by multiple implementations. It is not required to implement this trait if /// it is easier to implement the [TraceStep] trait directly without this trait. pub trait AdapterTraceStep { - /// Adapter row width - const WIDTH: usize; type ReadData; type WriteData; - /// The minimal amount of information needed to generate the sub-row of the trace matrix. - /// This type has a lifetime so other context, such as references to other chips, can be - /// provided. - type TraceContext<'a> + // @dev This can either be a &mut _ type or a struct with &mut _ fields. + // The latter is helpful if we want to directly write certain values in place into a trace + // matrix. + type RecordMut<'a> where Self: 'a; - fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]); + // /// The minimal amount of information needed to generate the sub-row of the trace matrix. + // /// This type has a lifetime so other context, such as references to other chips, can be + // /// provided. + // type TraceContext<'a> + // where + // Self: 'a; + + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>); fn read( &self, memory: &mut TracingMemory, instruction: &Instruction, - adapter_row: &mut [F], + record: &mut Self::RecordMut<'_>, ) -> Self::ReadData; fn write( &self, memory: &mut TracingMemory, instruction: &Instruction, - adapter_row: &mut [F], data: &Self::WriteData, + record: &mut Self::RecordMut<'_>, ); +} + +// NOTE[jpw]: cannot reuse `TraceSubRowGenerator` trait because we need associated constant +// `WIDTH`. +pub trait AdapterTraceFiller { + /// Adapter sub-air column width + const WIDTH: usize; - // Note[jpw]: should we reuse TraceSubRowGenerator trait instead? /// Post-execution filling of rest of adapter row. - fn fill_trace_row( - &self, - mem_helper: &MemoryAuxColsFactory, - ctx: Self::TraceContext<'_>, - adapter_row: &mut [F], - ); + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, adapter_row: &mut [F]); } pub trait AdapterExecutorE1 @@ -403,7 +420,7 @@ pub trait StepExecutorE1 { const DEFAULT_RECORDS_CAPACITY: usize = 1 << 20; -impl InsExecutorE1 for NewVmChipWrapper +impl InsExecutorE1 for NewVmChipWrapper where F: PrimeField32, S: StepExecutorE1, @@ -685,49 +702,6 @@ mod conversions { } } - // AdapterRuntimeContext: VecHeapAdapterInterface -> DynInterface - impl< - T, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - >, - > for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - >, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: DynInterface -> VecHeapAdapterInterface impl< T, @@ -759,35 +733,6 @@ mod conversions { } } - // AdapterRuntimeContext: DynInterface -> VecHeapAdapterInterface - impl< - T, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: DynInterface -> VecHeapTwoReadsAdapterInterface impl< T: Clone, @@ -819,95 +764,6 @@ mod conversions { } } - // AdapterRuntimeContext: DynInterface -> VecHeapAdapterInterface - impl< - T, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - VecHeapTwoReadsAdapterInterface< - T, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - - // AdapterRuntimeContext: BasicInterface -> VecHeapAdapterInterface - impl< - T, - PI, - const BASIC_NUM_READS: usize, - const BASIC_NUM_WRITES: usize, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface< - T, - PI, - BASIC_NUM_READS, - BASIC_NUM_WRITES, - READ_SIZE, - WRITE_SIZE, - >, - >, - > - for AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface< - T, - PI, - BASIC_NUM_READS, - BASIC_NUM_WRITES, - READ_SIZE, - WRITE_SIZE, - >, - >, - ) -> Self { - assert_eq!(BASIC_NUM_WRITES, BLOCKS_PER_WRITE); - let mut writes_it = ctx.writes.into_iter(); - let writes = from_fn(|_| writes_it.next().unwrap()); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - // AdapterAirContext: BasicInterface -> VecHeapAdapterInterface impl< T, @@ -1062,79 +918,6 @@ mod conversions { } } - // AdapterRuntimeContext: BasicInterface -> FlatInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - const READ_CELLS: usize, - const WRITE_CELLS: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - > for AdapterRuntimeContext> - { - /// ## Panics - /// If `WRITE_CELLS != NUM_WRITES * WRITE_SIZE`. - /// This is a runtime assertion until Rust const generics expressions are stabilized. - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - ) -> AdapterRuntimeContext> { - assert_eq!(WRITE_CELLS, NUM_WRITES * WRITE_SIZE); - let mut writes_it = ctx.writes.into_iter().flatten(); - let writes = from_fn(|_| writes_it.next().unwrap()); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - - // AdapterRuntimeContext: FlatInterface -> BasicInterface - impl< - T: FieldAlgebra, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - const READ_CELLS: usize, - const WRITE_CELLS: usize, - > From>> - for AdapterRuntimeContext< - T, - BasicAdapterInterface, - > - { - /// ## Panics - /// If `WRITE_CELLS != NUM_WRITES * WRITE_SIZE`. - /// This is a runtime assertion until Rust const generics expressions are stabilized. - fn from( - ctx: AdapterRuntimeContext>, - ) -> AdapterRuntimeContext< - T, - BasicAdapterInterface, - > { - assert_eq!(WRITE_CELLS, NUM_WRITES * WRITE_SIZE); - let mut writes_it = ctx.writes.into_iter(); - let writes: [[T; WRITE_SIZE]; NUM_WRITES] = - from_fn(|_| from_fn(|_| writes_it.next().unwrap())); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - impl From> for DynArray { fn from(v: Vec) -> Self { Self(v) @@ -1246,35 +1029,6 @@ mod conversions { } } - // AdapterRuntimeContext: BasicInterface -> DynInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - > for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: DynInterface -> BasicInterface impl< T, @@ -1301,28 +1055,6 @@ mod conversions { } } - // AdapterRuntimeContext: DynInterface -> BasicInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - BasicAdapterInterface, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: FlatInterface -> DynInterface impl>, const READ_CELLS: usize, const WRITE_CELLS: usize> From>> @@ -1338,21 +1070,6 @@ mod conversions { } } - // AdapterRuntimeContext: FlatInterface -> DynInterface - impl - From>> - for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext>, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.to_vec().into(), - } - } - } - impl From> for DynArray { fn from(m: MinimalInstruction) -> Self { Self(vec![m.is_valid, m.opcode]) diff --git a/crates/vm/src/system/memory/controller/mod.rs b/crates/vm/src/system/memory/controller/mod.rs index 28e0254c33..2c13147131 100644 --- a/crates/vm/src/system/memory/controller/mod.rs +++ b/crates/vm/src/system/memory/controller/mod.rs @@ -786,9 +786,11 @@ pub struct MemoryAuxColsFactory<'a, T> { // parallelized trace generation. impl MemoryAuxColsFactory<'_, F> { /// Fill the trace assuming `prev_timestamp` is already provided in `buffer`. - pub fn fill_from_prev(&self, timestamp: u32, buffer: &mut MemoryBaseAuxCols) { - let prev_timestamp = buffer.prev_timestamp.as_canonical_u32(); + pub fn fill(&self, prev_timestamp: u32, timestamp: u32, buffer: &mut MemoryBaseAuxCols) { self.generate_timestamp_lt(prev_timestamp, timestamp, &mut buffer.timestamp_lt_aux); + // Safety: even if prev_timestamp were obtained by transmute_ref from + // `buffer.prev_timestamp`, this should still work because it is a direct assignment + buffer.prev_timestamp = F::from_canonical_u32(prev_timestamp); } fn generate_timestamp_lt( @@ -806,18 +808,6 @@ impl MemoryAuxColsFactory<'_, F> { &mut buffer.lower_decomp, ); } - - fn generate_timestamp_lt_cols( - &self, - prev_timestamp: u32, - timestamp: u32, - ) -> LessThanAuxCols { - debug_assert!(prev_timestamp < timestamp); - let mut decomp = [F::ZERO; AUX_LEN]; - self.timestamp_lt_air - .generate_subrow((self.range_checker, prev_timestamp, timestamp), &mut decomp); - LessThanAuxCols::new(decomp) - } } impl SharedMemoryHelper { diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index 677e04fd31..c4e4ee082f 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -1,8 +1,6 @@ //! Defines auxiliary columns for memory operations: `MemoryReadAuxCols`, //! `MemoryReadWithImmediateAuxCols`, and `MemoryWriteAuxCols`. -use std::ops::DerefMut; - use openvm_circuit_primitives::is_less_than::LessThanAuxCols; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::p3_field::PrimeField32; @@ -11,8 +9,8 @@ use crate::system::memory::offline_checker::bridge::AUX_LEN; // repr(C) is needed to make sure that the compiler does not reorder the fields // we assume the order of the fields when using borrow or borrow_mut -#[repr(C)] /// Base structure for auxiliary memory columns. +#[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryBaseAuxCols { /// The previous timestamps in which the cells were accessed. @@ -61,9 +59,8 @@ impl MemoryWriteAuxCols { &self.prev_data } - /// Sets the previous timestamp and data **without** updating the less than auxiliary columns. - pub fn set_prev(&mut self, timestamp: T, data: [T; N]) { - self.base.prev_timestamp = timestamp; + /// Sets the previous data **without** updating the less than auxiliary columns. + pub fn set_prev_data(&mut self, data: [T; N]) { self.prev_data = data; } } diff --git a/crates/vm/src/system/memory/offline_checker/mod.rs b/crates/vm/src/system/memory/offline_checker/mod.rs index ac9f32dc18..e0712d8cd7 100644 --- a/crates/vm/src/system/memory/offline_checker/mod.rs +++ b/crates/vm/src/system/memory/offline_checker/mod.rs @@ -5,3 +5,10 @@ mod columns; pub use bridge::*; pub use bus::*; pub use columns::*; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; + +#[repr(C)] +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable)] +pub struct MemoryReadAuxRecord { + pub prev_timestamp: u32, +} diff --git a/crates/vm/src/system/native_adapter/mod.rs b/crates/vm/src/system/native_adapter/mod.rs index 9fdde040ce..baa8117351 100644 --- a/crates/vm/src/system/native_adapter/mod.rs +++ b/crates/vm/src/system/native_adapter/mod.rs @@ -5,8 +5,8 @@ use std::{ use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterInterface, + AdapterAirContext, BasicAdapterInterface, ExecutionBridge, ExecutionBus, ExecutionState, + MinimalInstruction, Result, VmAdapterAir, VmAdapterInterface, }, system::{ memory::{ @@ -195,16 +195,13 @@ impl AdapterTraceStep for Native where F: PrimeField32, { - const WIDTH: usize = size_of::>(); type ReadData = [[F; 1]; R]; type WriteData = [[F; 1]; W]; - type TraceContext<'a> = (); + type RecordMut<'a> = (); // TODO #[inline(always)] - fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { - let adapter_row: &mut NativeAdapterCols = adapter_row.borrow_mut(); - adapter_row.from_state.pc = F::from_canonical_u32(pc); - adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + fn start(pc: u32, memory: &TracingMemory, record: &mut ()) { + todo!() } #[inline(always)] @@ -212,7 +209,7 @@ where &self, memory: &mut TracingMemory, instruction: &Instruction, - adapter_row: &mut [F], + record: &mut (), ) -> Self::ReadData { todo!("Implement read operation"); } @@ -222,21 +219,11 @@ where &self, memory: &mut TracingMemory, instruction: &Instruction, - adapter_row: &mut [F], data: &Self::WriteData, + record: &mut (), ) { todo!("Implement write operation"); } - - #[inline(always)] - fn fill_trace_row( - &self, - mem_helper: &MemoryAuxColsFactory, - bitwise_lookup_chip: Self::TraceContext<'_>, - adapter_row: &mut [F], - ) { - todo!("Implement fill_trace_row operation"); - } } impl AdapterExecutorE1 for NativeAdapterStep diff --git a/crates/vm/src/system/public_values/core.rs b/crates/vm/src/system/public_values/core.rs index a573ae8e1e..1919be3eee 100644 --- a/crates/vm/src/system/public_values/core.rs +++ b/crates/vm/src/system/public_values/core.rs @@ -1,4 +1,4 @@ -use std::{borrow::BorrowMut, sync::Mutex}; +use std::sync::Mutex; use openvm_circuit_primitives::{encoder::Encoder, SubAir}; use openvm_instructions::{ @@ -11,21 +11,19 @@ use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::{AirBuilder, AirBuilderWithPublicValues, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, rap::BaseAirWithPublicValues, }; use serde::{Deserialize, Serialize}; use crate::{ arch::{ - AdapterAirContext, AdapterExecutorE1, AdapterRuntimeContext, AdapterTraceStep, - BasicAdapterInterface, MinimalInstruction, Result, StepExecutorE1, TraceStep, - VmAdapterInterface, VmCoreAir, VmCoreChip, VmStateMut, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, BasicAdapterInterface, EmptyLayout, + MatrixRecordArena, MinimalInstruction, RecordArena, Result, RowMajorMatrixArena, + StepExecutorE1, TraceFiller, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, }, system::{ - memory::{ - online::{GuestMemory, TracingMemory}, - MemoryAuxColsFactory, - }, + memory::online::{GuestMemory, TracingMemory}, public_values::columns::PublicValuesCoreColsView, }, }; @@ -150,9 +148,12 @@ where CTX, ReadData = [[F; 1]; 2], WriteData = [[F; 1]; 0], - TraceContext<'a> = (), + RecordMut<'a> = (), >, { + type RecordLayout = EmptyLayout; + type RecordMut<'a> = (); // TODO + fn get_opcode_name(&self, opcode: usize) -> String { format!( "{:?}", @@ -160,41 +161,41 @@ where ) } - fn execute( + fn execute<'buf, RA>( &mut self, state: VmStateMut, CTX>, instruction: &Instruction, - trace: &mut [F], - trace_offset: &mut usize, - width: usize, - ) -> Result<()> { - todo!("Implement execute function"); + arena: &mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + Self: 'buf, + { + todo!() } - fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { - todo!("Implement fill_trace_row function"); + // fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + // let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; - // let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + // self.adapter.fill_trace_row(mem_helper, (), adapter_row); - // self.adapter.fill_trace_row(mem_helper, (), adapter_row); + // let core_row: &mut PublicValuesCoreColsView<_, F> = core_row.borrow_mut(); - // let core_row: &mut PublicValuesCoreColsView<_, F> = core_row.borrow_mut(); + // // TODO(ayush): add this check + // // debug_assert_eq!(core_row.width(), BaseAir::::width(&self.air)); - // // TODO(ayush): add this check - // // debug_assert_eq!(core_row.width(), BaseAir::::width(&self.air)); + // core_row.is_valid = F::ONE; + // core_row.value = record.value; + // core_row.index = record.index; - // core_row.is_valid = F::ONE; - // core_row.value = record.value; - // core_row.index = record.index; + // let idx: usize = record.index.as_canonical_u32() as usize; - // let idx: usize = record.index.as_canonical_u32() as usize; + // let pt = self.air.encoder.get_flag_pt(idx); - // let pt = self.air.encoder.get_flag_pt(idx); - - // for (i, var) in core_row.custom_pv_vars.iter_mut().enumerate() { - // *var = F::from_canonical_u32(pt[i]); - // } - } + // for (i, var) in core_row.custom_pv_vars.iter_mut().enumerate() { + // *var = F::from_canonical_u32(pt[i]); + // } + // } fn generate_public_values(&self) -> Vec { self.get_custom_public_values() @@ -204,6 +205,27 @@ where } } +impl TraceFiller for PublicValuesStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = [[F; 1]; 2], + WriteData = [[F; 1]; 0], + RecordMut<'a> = (), + >, +{ + fn fill_trace_row( + &self, + mem_helper: &crate::system::memory::MemoryAuxColsFactory, + row_slice: &mut [F], + ) { + todo!() + } +} + impl StepExecutorE1 for PublicValuesStep where F: PrimeField32, @@ -235,6 +257,36 @@ where } } +pub struct PublicValuesRecordArena { + inner: MatrixRecordArena, +} + +impl<'a, F: PrimeField32> RecordArena<'a, EmptyLayout, ()> for PublicValuesRecordArena { + fn alloc(&'a mut self, layout: EmptyLayout) -> () { + todo!() + } +} + +impl RowMajorMatrixArena for PublicValuesRecordArena { + fn with_capacity(height: usize, width: usize) -> Self { + Self { + inner: MatrixRecordArena::with_capacity(height, width), + } + } + + fn width(&self) -> usize { + self.inner.width() + } + + fn trace_offset(&self) -> usize { + self.inner.trace_offset() + } + + fn into_matrix(self) -> RowMajorMatrix { + self.inner.into_matrix() + } +} + // /// ATTENTION: If a specific public value is not provided, a default 0 will be used when // generating /// the proof but in the perspective of constraints, it could be any value. // pub struct PublicValuesCoreChip { diff --git a/crates/vm/src/system/public_values/mod.rs b/crates/vm/src/system/public_values/mod.rs index 8b849ab819..7ca307ea06 100644 --- a/crates/vm/src/system/public_values/mod.rs +++ b/crates/vm/src/system/public_values/mod.rs @@ -1,7 +1,7 @@ -use core::PublicValuesStep; +use core::{PublicValuesRecordArena, PublicValuesStep}; use crate::{ - arch::{NewVmChipWrapper, VmAirWrapper}, + arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}, system::{ native_adapter::{NativeAdapterAir, NativeAdapterStep}, public_values::core::PublicValuesCoreAir, @@ -17,4 +17,9 @@ mod tests; pub type PublicValuesAir = VmAirWrapper, PublicValuesCoreAir>; pub type PublicValuesStepWithAdapter = PublicValuesStep, F>; -pub type PublicValuesChip = NewVmChipWrapper>; +pub type PublicValuesChip = NewVmChipWrapper< + F, + PublicValuesAir, + PublicValuesStepWithAdapter, + PublicValuesRecordArena, +>; diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 0e28dd3093..7e19bf4b37 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -21,6 +21,7 @@ derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true eyre.workspace = true +zerocopy = { workspace = true, features = ["derive"] } # for div_rem: num-bigint.workspace = true diff --git a/extensions/rv32im/circuit/src/adapters/alu.rs b/extensions/rv32im/circuit/src/adapters/alu.rs index a4b930f919..1778203f39 100644 --- a/extensions/rv32im/circuit/src/adapters/alu.rs +++ b/extensions/rv32im/circuit/src/adapters/alu.rs @@ -1,12 +1,17 @@ -use std::borrow::{Borrow, BorrowMut}; +use std::{ + borrow::{Borrow, BorrowMut}, + ptr::slice_from_raw_parts, +}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, BasicAdapterInterface, - ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, + AdapterAirContext, AdapterExecutorE1, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, system::memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + }, online::{GuestMemory, TracingMemory}, MemoryAddress, MemoryAuxColsFactory, }, @@ -14,6 +19,7 @@ use openvm_circuit::{ use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -26,9 +32,11 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; use super::{ - tracing_read, tracing_read_imm, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + tracing_read, tracing_read_imm, tracing_write, Rv32WordWriteAuxRecord, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, }; use crate::adapters::{memory_read, memory_write}; @@ -166,27 +174,47 @@ pub struct Rv32BaseAluAdapterStep { pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } +// Intermediate type that should not be copied or cloned and should be directly written to +#[repr(C)] +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable)] +pub struct Rv32BaseAluAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + // Pack u8 together for alignment + pub rd_ptr: u8, + pub rs1_ptr: u8, + /// 1 if rs2 was a read, 0 if an immediate + pub rs2_as: u8, + pub _padding: u8, + + /// Pointer if rs2 was a read, immediate value otherwise + pub rs2: u32, + + pub reads_aux: [MemoryReadAuxRecord; 2], + pub writes_aux: Rv32WordWriteAuxRecord, +} + impl AdapterTraceStep for Rv32BaseAluAdapterStep { - const WIDTH: usize = size_of::>(); type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; - type TraceContext<'a> = (); + type RecordMut<'a> = &'a mut Rv32BaseAluAdapterRecord; #[inline(always)] - fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { - let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); - adapter_row.from_state.pc = F::from_canonical_u32(pc); - adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + fn start(pc: u32, memory: &TracingMemory, record: &mut &mut Rv32BaseAluAdapterRecord) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } + // @dev cannot get rid of double &mut due to trait #[inline(always)] fn read( &self, memory: &mut TracingMemory, instruction: &Instruction, - adapter_row: &mut [F], + record: &mut &mut Rv32BaseAluAdapterRecord, ) -> Self::ReadData { let &Instruction { b, c, d, e, .. } = instruction; @@ -195,30 +223,28 @@ impl AdapterTraceStep e.as_canonical_u32() == RV32_REGISTER_AS || e.as_canonical_u32() == RV32_IMM_AS ); - let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); - - adapter_row.rs1_ptr = b; + record.rs1_ptr = b.as_canonical_u32() as u8; let rs1 = tracing_read( memory, RV32_REGISTER_AS, - b.as_canonical_u32(), - &mut adapter_row.reads_aux[0], + record.rs1_ptr as u32, + &mut record.reads_aux[0].prev_timestamp, ); let rs2 = if e.as_canonical_u32() == RV32_REGISTER_AS { - adapter_row.rs2_as = e; - adapter_row.rs2 = c; + record.rs2_as = RV32_REGISTER_AS as u8; + record.rs2 = c.as_canonical_u32(); tracing_read( memory, RV32_REGISTER_AS, - c.as_canonical_u32(), - &mut adapter_row.reads_aux[1], + record.rs2, + &mut record.reads_aux[1].prev_timestamp, ) } else { - adapter_row.rs2_as = e; + record.rs2_as = RV32_IMM_AS as u8; - tracing_read_imm(memory, c.as_canonical_u32(), &mut adapter_row.rs2) + tracing_read_imm(memory, c.as_canonical_u32(), &mut record.rs2) }; [rs1, rs2] @@ -229,50 +255,86 @@ impl AdapterTraceStep &self, memory: &mut TracingMemory, instruction: &Instruction, - adapter_row: &mut [F], data: &Self::WriteData, + record: &mut &mut Rv32BaseAluAdapterRecord, ) { let &Instruction { a, d, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); - - adapter_row.rd_ptr = a; + record.rd_ptr = a.as_canonical_u32() as u8; tracing_write( memory, RV32_REGISTER_AS, - a.as_canonical_u32(), + record.rd_ptr as u32, &data[0], - &mut adapter_row.writes_aux, + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, ); } +} - #[inline(always)] - fn fill_trace_row( - &self, - mem_helper: &MemoryAuxColsFactory, - _ctx: (), - adapter_row: &mut [F], - ) { - let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); - - let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); - - mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[0].as_mut()); - timestamp += 1; +impl AdapterTraceFiller + for Rv32BaseAluAdapterStep +{ + const WIDTH: usize = size_of::>(); - if !adapter_row.rs2_as.is_zero() { - mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[1].as_mut()); - } else { - let rs2_imm = adapter_row.rs2.as_canonical_u32(); - let mask = (1 << RV32_CELL_BITS) - 1; - self.bitwise_lookup_chip - .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask); + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, adapter_row: &mut [F]) { + let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); + // SAFETY: the following is highly unsafe. We are going to cast `adapter_row` to a record + // buffer, and then do an _overlapping_ write to the `adapter_row` as a row of field + // elements. This requires: + // - Cols struct should be repr(C) and we write in reverse order (to ensure non-overlapping) + // - Do not overwrite any reference in `record` before it has already been used or moved + // - alignment of `F` must be >= alignment of Record (zerocopy will panic otherwise) + unsafe { + let ptr = adapter_row as *mut _ as *mut u8; + let record_buffer = &*slice_from_raw_parts(ptr, size_of::()); + let (record, _) = Rv32BaseAluAdapterRecord::ref_from_prefix(record_buffer).unwrap(); + // We must assign in reverse + // TODO[jpw]: is there a way to not hardcode? + const TIMESTAMP_DELTA: u32 = 2; + let mut timestamp = record.from_timestamp + TIMESTAMP_DELTA; + + adapter_row + .writes_aux + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + timestamp, + adapter_row.writes_aux.as_mut(), + ); + timestamp -= 1; + + let rs2_as = record.rs2_as; + if rs2_as != 0 { + mem_helper.fill( + record.reads_aux[1].prev_timestamp, + timestamp, + adapter_row.reads_aux[1].as_mut(), + ); + } else { + let rs2_imm = adapter_row.rs2.as_canonical_u32(); + let mask = (1 << RV32_CELL_BITS) - 1; + self.bitwise_lookup_chip + .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask); + } + timestamp -= 1; + + mem_helper.fill( + record.reads_aux[0].prev_timestamp, + timestamp, + adapter_row.reads_aux[0].as_mut(), + ); + + // Write to rs2 first just in case since it appears later in Record + adapter_row.rs2 = F::from_canonical_u32(record.rs2); + adapter_row.rs2_as = F::from_canonical_u8(rs2_as); + adapter_row.rs1_ptr = F::from_canonical_u8(record.rs1_ptr); + adapter_row.rd_ptr = F::from_canonical_u8(record.rd_ptr); + adapter_row.from_state.timestamp = F::from_canonical_u32(timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } - timestamp += 1; - - mem_helper.fill_from_prev(timestamp, adapter_row.writes_aux.as_mut()); } } diff --git a/extensions/rv32im/circuit/src/adapters/mod.rs b/extensions/rv32im/circuit/src/adapters/mod.rs index ba458930df..019cdbaffd 100644 --- a/extensions/rv32im/circuit/src/adapters/mod.rs +++ b/extensions/rv32im/circuit/src/adapters/mod.rs @@ -1,7 +1,9 @@ use std::ops::Mul; use openvm_circuit::system::memory::{ - offline_checker::{MemoryBaseAuxCols, MemoryReadAuxCols, MemoryWriteAuxCols}, + offline_checker::{ + MemoryBaseAuxCols, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + }, online::{GuestMemory, TracingMemory}, tree::public_values::PUBLIC_VALUES_AS, MemoryController, RecordId, @@ -23,6 +25,7 @@ pub use loadstore::*; pub use mul::*; pub use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; pub use rdwrite::*; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; /// 256-bit heap integer stored as 32 bytes (32 limbs of 8-bits) pub const INT256_NUM_LIMBS: usize = 32; @@ -35,6 +38,13 @@ pub const RV_B_TYPE_IMM_BITS: usize = 13; pub const RV_J_TYPE_IMM_BITS: usize = 21; +#[repr(C)] +#[derive(FromBytes, IntoBytes, KnownLayout, Immutable)] +pub struct Rv32WordWriteAuxRecord { + pub prev_timestamp: u32, + pub prev_data: [u8; RV32_REGISTER_NUM_LIMBS], +} + /// Convert the RISC-V register data (32 bits represented as 4 bytes, where each byte is represented /// as a field element) back into its value as u32. pub fn compose(ptr_data: [F; RV32_REGISTER_NUM_LIMBS]) -> u32 { @@ -135,14 +145,13 @@ pub fn tracing_read( memory: &mut TracingMemory, address_space: u32, ptr: u32, - aux_cols: &mut MemoryReadAuxCols, /* TODO[jpw]: switch to raw u8 - * buffer */ + prev_timestamp: &mut u32, ) -> [u8; N] where F: PrimeField32, { let (t_prev, data) = timed_read(memory, address_space, ptr); - aux_cols.set_prev(F::from_canonical_u32(t_prev)); + *prev_timestamp = t_prev; data } @@ -154,17 +163,14 @@ pub fn tracing_write( address_space: u32, ptr: u32, data: &[u8; N], - aux_cols: &mut MemoryWriteAuxCols, /* TODO[jpw]: switch to raw - * u8 - * buffer */ + prev_timestamp: &mut u32, + prev_data: &mut [u8; N], ) where F: PrimeField32, { let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data); - aux_cols.set_prev( - F::from_canonical_u32(t_prev), - data_prev.map(F::from_canonical_u8), - ); + *prev_timestamp = t_prev; + *prev_data = data_prev; } // TODO(ayush): this is bad but not sure how to avoid @@ -186,12 +192,12 @@ pub fn tracing_write_with_base_aux( pub fn tracing_read_imm( memory: &mut TracingMemory, imm: u32, - imm_mut: &mut F, + imm_mut: &mut u32, ) -> [u8; RV32_REGISTER_NUM_LIMBS] where F: PrimeField32, { - *imm_mut = F::from_canonical_u32(imm); + *imm_mut = imm; debug_assert_eq!(imm >> 24, 0); // highest byte should be zero to prevent overflow memory.increment_timestamp(); diff --git a/extensions/rv32im/circuit/src/auipc/core.rs b/extensions/rv32im/circuit/src/auipc/core.rs index 1e746ccbe8..60bac7e0a3 100644 --- a/extensions/rv32im/circuit/src/auipc/core.rs +++ b/extensions/rv32im/circuit/src/auipc/core.rs @@ -29,7 +29,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; use crate::adapters::{Rv32RdWriteAdapterCols, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; @@ -239,7 +238,7 @@ where let local_opcode = Rv32AuipcOpcode::from_usize(opcode.local_opcode_idx(Rv32AuipcOpcode::CLASS_OFFSET)); - let mut row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; A::start(*state.pc, state.memory, adapter_row); diff --git a/extensions/rv32im/circuit/src/base_alu/core.rs b/extensions/rv32im/circuit/src/base_alu/core.rs index a653a69da3..e48d700014 100644 --- a/extensions/rv32im/circuit/src/base_alu/core.rs +++ b/extensions/rv32im/circuit/src/base_alu/core.rs @@ -2,12 +2,14 @@ use std::{ array, borrow::{Borrow, BorrowMut}, iter::zip, + ptr::slice_from_raw_parts, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, - StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + AdapterAirContext, AdapterExecutorE1, AdapterTraceFiller, AdapterTraceStep, EmptyLayout, + MatrixRecordArena, MinimalInstruction, RecordArena, Result, RowMajorMatrixArena, + StepExecutorE1, TraceFiller, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, }, system::memory::{ online::{GuestMemory, TracingMemory}, @@ -17,17 +19,24 @@ use openvm_circuit::{ use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, +}; use openvm_rv32im_transpiler::BaseAluOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, rap::BaseAirWithPublicValues, }; use strum::IntoEnumIterator; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; + +use crate::adapters::{Rv32BaseAluAdapterCols, Rv32BaseAluAdapterRecord}; #[repr(C)] #[derive(AlignedBorrow)] @@ -190,6 +199,16 @@ impl BaseAluStep { + pub a: [u8; NUM_LIMBS], + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + // Use u8 instead of usize for better packing + pub local_opcode: u8, +} + impl TraceStep for BaseAluStep where @@ -200,78 +219,162 @@ where CTX, ReadData: Into<[[u8; NUM_LIMBS]; 2]>, WriteData: From<[[u8; NUM_LIMBS]; 1]>, - TraceContext<'a> = (), >, { + /// Instructions that use one trace row per instruction have implicit layout + type RecordLayout = EmptyLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut BaseAluCoreRecord); + fn get_opcode_name(&self, opcode: usize) -> String { format!("{:?}", BaseAluOpcode::from_usize(opcode - self.offset)) } - fn execute( + fn execute<'buf, RA>( &mut self, state: VmStateMut, CTX>, instruction: &Instruction, - trace: &mut [F], - trace_offset: &mut usize, - width: usize, - ) -> Result<()> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, EmptyLayout, Self::RecordMut<'buf>>, + { let Instruction { opcode, .. } = instruction; let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let row_slice = &mut trace[*trace_offset..*trace_offset + width]; - let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + let (mut adapter_record, core_record) = arena.alloc(EmptyLayout); - A::start(*state.pc, state.memory, adapter_row); + A::start(*state.pc, state.memory, &mut adapter_record); let [rs1, rs2] = self .adapter - .read(state.memory, instruction, adapter_row) + .read(state.memory, instruction, &mut adapter_record) .into(); let rd = run_alu::(local_opcode, &rs1, &rs2); - let core_row: &mut BaseAluCoreCols = core_row.borrow_mut(); - core_row.a = rd.map(F::from_canonical_u8); - core_row.b = rs1.map(F::from_canonical_u8); - core_row.c = rs2.map(F::from_canonical_u8); - core_row.opcode_add_flag = F::from_bool(local_opcode == BaseAluOpcode::ADD); - core_row.opcode_sub_flag = F::from_bool(local_opcode == BaseAluOpcode::SUB); - core_row.opcode_xor_flag = F::from_bool(local_opcode == BaseAluOpcode::XOR); - core_row.opcode_or_flag = F::from_bool(local_opcode == BaseAluOpcode::OR); - core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND); + core_record.a = rd; + core_record.b = rs1; + core_record.c = rs2; + core_record.local_opcode = local_opcode as u8; self.adapter - .write(state.memory, instruction, adapter_row, &[rd].into()); + .write(state.memory, instruction, &[rd].into(), &mut adapter_record); *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - *trace_offset += width; - Ok(()) } +} +impl TraceFiller + for BaseAluStep +where + F: PrimeField32, + A: 'static + for<'a> AdapterTraceFiller, +{ fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; - - self.adapter.fill_trace_row(mem_helper, (), adapter_row); + self.adapter.fill_trace_row(mem_helper, adapter_row); let core_row: &mut BaseAluCoreCols = core_row.borrow_mut(); - - if core_row.opcode_add_flag == F::ONE || core_row.opcode_sub_flag == F::ONE { - for a_val in core_row.a.map(|x| x.as_canonical_u32()) { - self.bitwise_lookup_chip.request_xor(a_val, a_val); - } - } else { - let b = core_row.b.map(|x| x.as_canonical_u32()); - let c = core_row.c.map(|x| x.as_canonical_u32()); - for (b_val, c_val) in zip(b, c) { - self.bitwise_lookup_chip.request_xor(b_val, c_val); + // SAFETY: the following is highly unsafe. We are going to cast `core_row` to a record + // buffer, and then do an _overlapping_ write to the `core_row` as a row of field elements. + // This requires: + // - Cols and Record structs should be repr(C) and we write in reverse order (to ensure + // non-overlapping) + // - Do not overwrite any reference in `record` before it has already been used or moved + // - alignment of `F` must be >= alignment of Record (zerocopy will panic otherwise) + unsafe { + let ptr = core_row as *mut _ as *mut u8; + let record_buffer = + &*slice_from_raw_parts(ptr, size_of::>()); + let (record, _) = BaseAluCoreRecord::ref_from_prefix(record_buffer).unwrap(); + + // PERF: needless conversion + let local_opcode = BaseAluOpcode::from_usize(record.local_opcode as usize); + core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND); + core_row.opcode_or_flag = F::from_bool(local_opcode == BaseAluOpcode::OR); + core_row.opcode_xor_flag = F::from_bool(local_opcode == BaseAluOpcode::XOR); + core_row.opcode_sub_flag = F::from_bool(local_opcode == BaseAluOpcode::SUB); + core_row.opcode_add_flag = F::from_bool(local_opcode == BaseAluOpcode::ADD); + + if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB { + for a_val in record.a { + self.bitwise_lookup_chip + .request_xor(a_val as u32, a_val as u32); + } + } else { + for (b_val, c_val) in zip(record.b, record.c) { + self.bitwise_lookup_chip + .request_xor(b_val as u32, c_val as u32); + } } + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = record.a.map(F::from_canonical_u8); } } } +pub struct Rv32BaseAluRecordArena { + inner: MatrixRecordArena, +} + +// NOTE[jpw]: this is an implementation only for RV32IM extension, not for bigint etc +impl<'a, F: PrimeField32> + RecordArena< + 'a, + EmptyLayout, + ( + &'a mut Rv32BaseAluAdapterRecord, + &'a mut BaseAluCoreRecord, + ), + > for Rv32BaseAluRecordArena +{ + fn alloc( + &'a mut self, + _: EmptyLayout, + ) -> ( + &'a mut Rv32BaseAluAdapterRecord, + &'a mut BaseAluCoreRecord, + ) { + let buffer = self.inner.alloc_single_row(); + // NOTE: the Cols type has generic because we want the size in bytes, not number of + // field elements + let (adapter_buffer, core_buffer) = + buffer.split_at_mut(size_of::>()); + // PERF: we could skip these unwraps if the RecordArena guarantees the size and alignment + // properties + let (adapter_record, _) = + Rv32BaseAluAdapterRecord::mut_from_prefix(adapter_buffer).unwrap(); + let (core_record, _) = BaseAluCoreRecord::mut_from_prefix(core_buffer).unwrap(); + + (adapter_record, core_record) + } +} + +// TODO: make a macro +impl RowMajorMatrixArena for Rv32BaseAluRecordArena { + fn with_capacity(height: usize, width: usize) -> Self { + Self { + inner: MatrixRecordArena::with_capacity(height, width), + } + } + + fn width(&self) -> usize { + self.inner.width() + } + + fn trace_offset(&self) -> usize { + self.inner.trace_offset() + } + + fn into_matrix(self) -> RowMajorMatrix { + self.inner.into_matrix() + } +} + impl StepExecutorE1 for BaseAluStep where diff --git a/extensions/rv32im/circuit/src/base_alu/mod.rs b/extensions/rv32im/circuit/src/base_alu/mod.rs index 266a7ee453..5c5b1e8781 100644 --- a/extensions/rv32im/circuit/src/base_alu/mod.rs +++ b/extensions/rv32im/circuit/src/base_alu/mod.rs @@ -14,4 +14,5 @@ pub type Rv32BaseAluAir = VmAirWrapper>; pub type Rv32BaseAluStep = BaseAluStep, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>; -pub type Rv32BaseAluChip = NewVmChipWrapper; +pub type Rv32BaseAluChip = + NewVmChipWrapper>;