diff --git a/.vscode/settings.json b/.vscode/settings.json index 988f817d02..d9a52c68d8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -27,6 +27,31 @@ "ios": "cpp", "iosfwd": "cpp", "vector": "cpp", - "charconv": "cpp" + "charconv": "cpp", + "__hash_table": "cpp", + "__split_buffer": "cpp", + "__tree": "cpp", + "array": "cpp", + "bitset": "cpp", + "deque": "cpp", + "initializer_list": "cpp", + "list": "cpp", + "map": "cpp", + "queue": "cpp", + "random": "cpp", + "regex": "cpp", + "set": "cpp", + "span": "cpp", + "stack": "cpp", + "string": "cpp", + "string_view": "cpp", + "unordered_map": "cpp", + "unordered_set": "cpp", + "valarray": "cpp", + "iterator": "cpp", + "utility": "cpp", + "rope": "cpp", + "slist": "cpp", + "ranges": "cpp" } } diff --git a/risc0/zkp/rust/src/core/sha.rs b/risc0/zkp/rust/src/core/sha.rs index c26d1f2318..c3e98baa43 100644 --- a/risc0/zkp/rust/src/core/sha.rs +++ b/risc0/zkp/rust/src/core/sha.rs @@ -77,6 +77,15 @@ impl Digest { &self.0 } + /// Returns as a slice of be u8 + pub fn get_u8(&self) -> [u8; DIGEST_WORDS * 4] { + let mut res: [u8; DIGEST_WORDS * 4] = [0; DIGEST_WORDS * 4]; + for i in 0..DIGEST_WORDS { + res[4 * i..][..4].copy_from_slice(&self.0[i].to_be_bytes()); + } + res + } + /// Returns a mutable slice of words. pub fn get_mut(&mut self) -> &mut [u32; DIGEST_WORDS] { &mut self.0 diff --git a/risc0/zkvm/platform/io.h b/risc0/zkvm/platform/io.h index 90a1e7aa4f..2f9925c524 100644 --- a/risc0/zkvm/platform/io.h +++ b/risc0/zkvm/platform/io.h @@ -26,6 +26,7 @@ constexpr size_t kGPIO_GetKey = 0x01F0010; constexpr size_t kGPIO_SendRecvChannel = 0x01F00014; constexpr size_t kGPIO_SendRecvSize = 0x01F00018; constexpr size_t kGPIO_SendRecvAddr = 0x01F0001C; +constexpr size_t kGPIO_Mul = 0x01F00020; // Standard ZKVM channels; must match zkvm/sdk/rust/platform/src/io.rs. @@ -35,6 +36,8 @@ constexpr uint32_t kSendRecvChannel_InitialInput = 0; constexpr uint32_t kSendRecvChannel_Stdout = 1; // Write bytes to standard error constexpr uint32_t kSendRecvChannel_Stderr = 2; +// Request aux tape to the guest +constexpr uint32_t kSendRecvChannel_InitialInputAux = 3; // To invoke accelerated SHA, the guest writes ShaDescriptor structs // in sequence to the "SHA" memory region. Once the ShaDescriptor has @@ -65,6 +68,15 @@ struct ShaDescriptor { uint32_t digest; }; +struct MulDescriptor { + // Address of first byte of MUL data to process + // 128 bits for first operand and 128 bits for second + uint32_t source; + + // 128 bit result + uint32_t result; +}; + inline volatile ShaDescriptor* volatile* GPIO_SHA() { return reinterpret_cast(kGPIO_SHA); } diff --git a/risc0/zkvm/platform/memory.h b/risc0/zkvm/platform/memory.h index 7b427a43f9..66723d37f3 100644 --- a/risc0/zkvm/platform/memory.h +++ b/risc0/zkvm/platform/memory.h @@ -43,8 +43,9 @@ MEM_REGION(Input, 0x01E00000, k1MB) MEM_REGION(GPIO, 0x01F00000, k1MB) MEM_REGION(Prog, 0x02000000, 10 * k1MB) MEM_REGION(SHA, 0x02A00000, k1MB) -MEM_REGION(WOM, 0x02B00000, 21 * k1MB) -MEM_REGION(Output, 0x02B00000, 20 * k1MB) +MEM_REGION(MUL, 0x02B00000, k1MB) +MEM_REGION(WOM, 0x02C00000, 20 * k1MB) +MEM_REGION(Output, 0x02C00000, 19 * k1MB) MEM_REGION(Commit, 0x03F00000, k1MB) // clang-format on diff --git a/risc0/zkvm/platform/risc0.ld b/risc0/zkvm/platform/risc0.ld index ed4e89e238..7b51400427 100644 --- a/risc0/zkvm/platform/risc0.ld +++ b/risc0/zkvm/platform/risc0.ld @@ -29,7 +29,8 @@ MEMORY { gpio : ORIGIN = 0x01F00000, LENGTH = 1M prog (X) : ORIGIN = 0x02000000, LENGTH = 10M sha : ORIGIN = 0x02A00000, LENGTH = 1M - wom : ORIGIN = 0x02B00000, LENGTH = 21M + mul : ORIGIN = 0x02B00000, LENGTH = 1M + wom : ORIGIN = 0x02C00000, LENGTH = 20M } SECTIONS { diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index 68784408ef..7bdc9dd0ed 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -23,6 +23,82 @@ namespace risc0 { +class FpG { + // implement just enough operations to support extension field multiplication + // all values are in mont form +public: + static CONSTSCALAR uint64_t M = 0xFFFFFFFF00000001; + uint64_t val; + +private: + static DEVSPEC constexpr uint64_t add(uint64_t a, uint64_t b) { + bool c1 = (M - b) > a; + uint64_t x1 = a - (M - b); + uint32_t adj = uint32_t(0) - uint32_t(c1); + uint64_t res = x1 - uint64_t(adj); + // std::cout << "c1: " << c1 << ", x1: " << x1 << ", adj: " << adj << ", res: " << res << + // std::endl; + return res; + } + + static DEVSPEC constexpr uint64_t sub(uint64_t a, uint64_t b) { + bool c1 = b > a; + uint64_t x1 = a - b; + uint32_t adj = 0 - uint32_t(c1); + return x1 - uint64_t(adj); + } + + static DEVSPEC constexpr uint64_t doubleVal(uint64_t a) { + __uint128_t ret = __uint128_t(a) << 1; + uint64_t result = uint64_t(ret); + uint64_t over = uint64_t(ret >> 64); + return result - (M * over); + } + + static DEVSPEC constexpr uint64_t montRedCst(__uint128_t n) { + uint64_t xl = uint64_t(n); + uint64_t xh = uint64_t(n >> 64); + bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; // overflow + uint64_t a = xl + (xl << 32); + uint64_t b = a - (a >> 32) - e; + bool c = xh < b; + uint64_t r = xh - b; + uint64_t mont_result = r - (uint32_t(0) - uint32_t(c)); + return mont_result; + } + + static DEVSPEC constexpr uint64_t mul(uint64_t a, uint64_t b) { + __uint128_t n = __uint128_t(a) * __uint128_t(b); + return montRedCst(n); + } + +public: + DEVSPEC constexpr FpG(uint64_t val) : val(val) {} + DEVSPEC constexpr FpG operator+(FpG rhs) const { return FpG(add(val, rhs.val)); } + DEVSPEC constexpr FpG operator-(FpG rhs) const { return FpG(sub(val, rhs.val)); } + DEVSPEC constexpr FpG operator*(FpG rhs) const { return FpG(mul(val, rhs.val)); } + DEVSPEC constexpr FpG doubleVal() const { return FpG(doubleVal(val)); } +}; + +static std::pair extensionMul(std::pair a, std::pair b) { + FpG a0b0 = a.first * b.first; + FpG a1b1 = a.second * b.second; + FpG first = a0b0 - a1b1.doubleVal(); + + FpG a0a1 = a.first + a.second; + FpG b0b1 = b.first + b.second; + FpG second = a0a1 * b0b1 - a0b0; + + // std::cout << "CPP a: [" << a.first.val << ", " << a.second.val << "]" << std::endl; + // std::cout << "b: [" << b.first.val << ", " << b.second.val << "]" << std::endl; + + // std::cout << "a0b0: " << a0b0.val << ", a1b1: " << a1b1.val << ", first: " << first.val + // << ", a0a1: " << a0a1.val << ", b0b1: " << b0b1.val << ", second: " + // << second.val << std::endl; + + return std::pair(first, second); +} + static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { uint16_t type = (desc.typeAndCount & 0xFFFF) >> 4; uint16_t count = desc.typeAndCount & 0xFFFF; @@ -46,6 +122,52 @@ static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { } } +static void processMul(MemoryState& mem, const MulDescriptor& desc) { + uint32_t a0_hi = mem.load(desc.source); + LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a0_hi)); + uint32_t a0_lo = mem.load(desc.source + 4); + LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a0_lo)); + uint32_t a1_hi = mem.load(desc.source + 8); + LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(a1_hi)); + uint32_t a1_lo = mem.load(desc.source + 12); + LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(a1_lo)); + + uint32_t b0_hi = mem.load(desc.source + 16); + LOG(1, "Input[" << hex(4, 2) << "]: " << hex(desc.source + 16) << " -> " << hex(b0_hi)); + uint32_t b0_lo = mem.load(desc.source + 20); + LOG(1, "Input[" << hex(5, 2) << "]: " << hex(desc.source + 20) << " -> " << hex(b0_lo)); + uint32_t b1_hi = mem.load(desc.source + 24); + LOG(1, "Input[" << hex(6, 2) << "]: " << hex(desc.source + 24) << " -> " << hex(b1_hi)); + uint32_t b1_lo = mem.load(desc.source + 28); + LOG(1, "Input[" << hex(7, 2) << "]: " << hex(desc.source + 28) << " -> " << hex(b1_lo)); + + uint64_t a0 = a0_lo | (uint64_t(a0_hi) << 32); + uint64_t a1 = a1_lo | (uint64_t(a1_hi) << 32); + uint64_t b0 = b0_lo | (uint64_t(b0_hi) << 32); + uint64_t b1 = b1_lo | (uint64_t(b1_hi) << 32); + + std::pair a = std::pair(FpG(a0), FpG(a1)); + std::pair b = std::pair(FpG(b0), FpG(b1)); + std::pair result = extensionMul(a, b); + + uint64_t r0 = result.first.val; + uint32_t r0_high = (uint32_t)((r0 & 0xFFFFFFFF00000000LL) >> 32); + uint32_t r0_low = (uint32_t)(r0 & 0xFFFFFFFFLL); + + uint64_t r1 = result.second.val; + uint32_t r1_high = (uint32_t)((r1 & 0xFFFFFFFF00000000LL) >> 32); + uint32_t r1_low = (uint32_t)(r1 & 0xFFFFFFFFLL); + + LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(r0_high)); + mem.store(desc.result, r0_high); + LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 4) << " <- " << hex(r0_low)); + mem.store(desc.result + 4, r0_low); + LOG(1, "Output[" << hex(2, 2) << "]: " << hex(desc.result + 8) << " <- " << hex(r1_high)); + mem.store(desc.result + 8, r1_high); + LOG(1, "Output[" << hex(3, 2) << "]: " << hex(desc.result + 12) << " <- " << hex(r1_low)); + mem.store(desc.result + 12, r1_low); +} + void IoHandler::onFault(const std::string& msg) { throw std::runtime_error(msg); } @@ -63,6 +185,13 @@ void MemoryHandler::onInit(MemoryState& mem) { void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value) { LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value)); switch (addr) { + case kGPIO_Mul: { + LOG(1, "MemoryHandler::onWrite> GPIO_MUL"); + MulDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + processMul(mem, desc); + break; + } case kGPIO_SHA: { LOG(1, "MemoryHandler::onWrite> GPIO_SHA"); ShaDescriptor desc; diff --git a/risc0/zkvm/sdk/cpp/host/c_api.cpp b/risc0/zkvm/sdk/cpp/host/c_api.cpp index 638dc54125..8f5de9fa43 100644 --- a/risc0/zkvm/sdk/cpp/host/c_api.cpp +++ b/risc0/zkvm/sdk/cpp/host/c_api.cpp @@ -129,6 +129,13 @@ void risc0_prover_add_input(risc0_error* err, risc0_prover* ptr, const uint8_t* ffi_wrap_void(err, [&] { ptr->prover->writeInput(buf, len); }); } +void risc0_prover_add_aux_input(risc0_error* err, + risc0_prover* ptr, + const uint8_t* buf, + size_t len) { + ffi_wrap_void(err, [&] { ptr->prover->writeInputAux(buf, len); }); +} + const void* risc0_prover_get_output_buf(risc0_error* err, const risc0_prover* ptr) { return ffi_wrap(err, nullptr, [&] { return ptr->prover->getOutput().data(); }); } diff --git a/risc0/zkvm/sdk/cpp/host/c_api.h b/risc0/zkvm/sdk/cpp/host/c_api.h index 20172ebdee..b9f851078d 100644 --- a/risc0/zkvm/sdk/cpp/host/c_api.h +++ b/risc0/zkvm/sdk/cpp/host/c_api.h @@ -89,6 +89,11 @@ void risc0_prover_free(risc0_error* err, risc0_prover* ptr); void risc0_prover_add_input(risc0_error* err, risc0_prover* ptr, const uint8_t* buf, size_t len); +void risc0_prover_add_aux_input(risc0_error* err, + risc0_prover* ptr, + const uint8_t* buf, + size_t len); + size_t risc0_prover_get_num_outputs(risc0_error* err, risc0_prover* ptr); const void* risc0_prover_get_output_buf(risc0_error* err, const risc0_prover* ptr); diff --git a/risc0/zkvm/sdk/cpp/host/receipt.cpp b/risc0/zkvm/sdk/cpp/host/receipt.cpp index fb7d3eaf9d..732ab3e203 100644 --- a/risc0/zkvm/sdk/cpp/host/receipt.cpp +++ b/risc0/zkvm/sdk/cpp/host/receipt.cpp @@ -63,6 +63,7 @@ struct Prover::Impl : public IoHandler { , outputStream(outputBuffer) , commitStream(commitBuffer) , inputWriter(inputStream) + , inputWriterAux(inputStreamAux) , outputReader(outputStream) , commitReader(commitStream) { // Set default handlers: @@ -83,6 +84,13 @@ struct Prover::Impl : public IoHandler { LOG(1, "IoHandler::InitialInput, " << input.size() << " bytes"); return input; }); + setSendRecvHandler( + kSendRecvChannel_InitialInputAux, [this](uint32_t, const BufferU8& buf) -> BufferU8 { + const uint8_t* byte_ptr = reinterpret_cast(inputStreamAux.vec.data()); + BufferU8 input(byte_ptr, byte_ptr + inputStreamAux.vec.size() * sizeof(uint32_t)); + LOG(1, "IoHandler::InitialInputAux, " << input.size() << " bytes"); + return input; + }); } virtual ~Impl() {} @@ -116,9 +124,11 @@ struct Prover::Impl : public IoHandler { BufferU8 outputBuffer; BufferU8 commitBuffer; VectorStreamWriter inputStream; + VectorStreamWriter inputStreamAux; CheckedStreamReader outputStream; CheckedStreamReader commitStream; ArchiveWriter inputWriter; + ArchiveWriter inputWriterAux; ArchiveReader outputReader; ArchiveReader commitReader; @@ -219,6 +229,31 @@ void Prover::writeInput(const void* ptr, size_t size) { } } +void Prover::writeInputAux(const void* ptr, size_t size) { + LOG(1, "Prover::writeInputAux> size: " << size); + const uint8_t* ptr_u8 = static_cast(ptr); + while (size >= sizeof(uint32_t)) { + uint32_t word = 0; + word |= *ptr_u8++; + word |= *ptr_u8++ << 8; + word |= *ptr_u8++ << 16; + word |= *ptr_u8++ << 24; + LOG(1, " write_word: " << hex(word)); + impl->inputStreamAux.write_word(word); + size -= sizeof(uint32_t); + } + + if (size) { + LOG(1, " tail: " << size); + uint32_t word = 0; + for (size_t i = 0; i < size; i++) { + word |= *ptr_u8++ << (8 * i); + } + LOG(1, " write_word: " << hex(word)); + impl->inputStreamAux.write_word(word); + } +} + void Prover::setSendRecvHandler( uint32_t channelId, const std::function& handler) { diff --git a/risc0/zkvm/sdk/cpp/host/receipt.h b/risc0/zkvm/sdk/cpp/host/receipt.h index 9e419935fb..ebb7dadb2f 100644 --- a/risc0/zkvm/sdk/cpp/host/receipt.h +++ b/risc0/zkvm/sdk/cpp/host/receipt.h @@ -87,6 +87,8 @@ class Prover { void writeInput(const void* ptr, size_t size); + void writeInputAux(const void* ptr, size_t size); + template void writeInput(const T& obj) { getInputWriter().transfer(obj); } const BufferU8& getOutput(); diff --git a/risc0/zkvm/sdk/rust/guest/src/env.rs b/risc0/zkvm/sdk/rust/guest/src/env.rs index 2e54466da5..896eb7d0b7 100644 --- a/risc0/zkvm/sdk/rust/guest/src/env.rs +++ b/risc0/zkvm/sdk/rust/guest/src/env.rs @@ -17,7 +17,10 @@ use core::{cell::UnsafeCell, mem::MaybeUninit, slice}; use risc0_zkp::core::sha::Digest; use risc0_zkvm::{ platform::{ - io::{IoDescriptor, GPIO_COMMIT, SENDRECV_CHANNEL_INITIAL_INPUT, SENDRECV_CHANNEL_STDOUT}, + io::{ + IoDescriptor, GPIO_COMMIT, GPIO_LOG, SENDRECV_CHANNEL_INITIAL_AUX_INPUT, + SENDRECV_CHANNEL_INITIAL_INPUT, SENDRECV_CHANNEL_STDOUT, + }, memory, WORD_SIZE, }, serde::{Deserializer, Serializer, Slice}, @@ -98,6 +101,11 @@ pub fn read>() -> T { ENV.get().read() } +/// Read private raw data from the host. +pub fn read_aux_input() -> &'static [u8] { + ENV.get().read_aux_input() +} + /// Write private data to the host. pub fn write(data: &T) { ENV.get().write(data); @@ -108,6 +116,15 @@ pub fn commit(data: &T) { ENV.get().commit(data); } +/// Print a message to the debug console. +pub fn log(msg: &str) { + // TODO: format! is expensive, replace with a better solution. + let msg = alloc_crate::format!("{}\0", msg); + let ptr = msg.as_ptr(); + memory_barrier(ptr); + unsafe { GPIO_LOG.as_ptr().write_volatile(ptr) }; +} + impl Env { fn new() -> Self { Env { @@ -140,6 +157,10 @@ impl Env { self.initial_input_reader.as_mut().unwrap() } + pub fn read_aux_input(&mut self) -> &[u8] { + self.send_recv(SENDRECV_CHANNEL_INITIAL_AUX_INPUT, &[]) + } + pub fn read>(&mut self) -> T { self.initial_input().read() } diff --git a/risc0/zkvm/sdk/rust/guest/src/lib.rs b/risc0/zkvm/sdk/rust/guest/src/lib.rs index d12df0a88f..094fc776d4 100644 --- a/risc0/zkvm/sdk/rust/guest/src/lib.rs +++ b/risc0/zkvm/sdk/rust/guest/src/lib.rs @@ -21,6 +21,7 @@ #![cfg_attr(target_arch = "riscv32", feature(new_uninit))] extern crate alloc as _alloc; +pub extern crate alloc as alloc_crate; #[cfg(not(feature = "std"))] mod alloc; @@ -31,6 +32,9 @@ pub mod env; /// Functions for computing SHA-256 hashes. pub mod sha; +/// mul +pub mod mul; + /// Functions for handling input and output pub mod io; diff --git a/risc0/zkvm/sdk/rust/guest/src/mul.rs b/risc0/zkvm/sdk/rust/guest/src/mul.rs new file mode 100644 index 0000000000..a10ecdc557 --- /dev/null +++ b/risc0/zkvm/sdk/rust/guest/src/mul.rs @@ -0,0 +1,77 @@ +use core::{cell::UnsafeCell, mem}; + +use crate::env::log; +use _alloc::format; +use _alloc::{boxed::Box, vec::Vec}; +use risc0_zkvm::platform::{ + io::{MulDescriptor, GPIO_MUL}, + memory, +}; + +// Current sha descriptor index. +struct CurOutput(UnsafeCell); + +// SAFETY: single threaded environment +unsafe impl Sync for CurOutput {} + +static CUR_OUTPUT: CurOutput = CurOutput(UnsafeCell::new(0)); + +/// Result of multiply goldilocks +pub struct MulGoldilocks([u32; 4]); + +impl MulGoldilocks { + /// Get the result as u64 + pub fn get_u64(&self) -> [u64; 2] { + [ + (self.0[1] as u64) | ((self.0[0] as u64) << 32), + (self.0[3] as u64) | ((self.0[2] as u64) << 32), + ] + } +} + +fn alloc_output() -> *mut MulDescriptor { + // SAFETY: Single threaded and this is the only place we use CUR_DESC. + unsafe { + let cur_desc = CUR_OUTPUT.0.get(); + let ptr = (memory::MUL.start() as *mut MulDescriptor).add(*cur_desc); + *cur_desc += 1; + ptr + } +} + +/// Multiply goldilocks oracle, verification is done separately +pub fn mul_goldilocks(a: &[u64; 2], b: &[u64; 2]) -> &'static MulGoldilocks { + let a0_hi = ((a[0] & 0xFFFFFFFF00000000) >> 32) as u32; + let a0_lo = (a[0] & 0xFFFFFFFF) as u32; + let a1_hi = ((a[1] & 0xFFFFFFFF00000000) >> 32) as u32; + let a1_lo = (a[1] & 0xFFFFFFFF) as u32; + + let b0_hi = ((b[0] & 0xFFFFFFFF00000000) >> 32) as u32; + let b0_lo = (b[0] & 0xFFFFFFFF) as u32; + let b1_hi = ((b[1] & 0xFFFFFFFF00000000) >> 32) as u32; + let b1_lo = (b[1] & 0xFFFFFFFF) as u32; + + let buf = [a0_hi, a0_lo, a1_hi, a1_lo, b0_hi, b0_lo, b1_hi, b1_lo]; + + unsafe { + let alloced = Box::>::new( + mem::MaybeUninit::::uninit(), + ); + let output = (*Box::into_raw(alloced)).as_mut_ptr(); + mul_raw(&buf[..], output); + &*output + } +} + +pub(crate) unsafe fn mul_raw(data: &[u32], result: *mut MulGoldilocks) { + let output_ptr = alloc_output(); + + let ptr = data.as_ptr(); + super::memory_barrier(ptr); + output_ptr.write_volatile(MulDescriptor { + source: ptr as usize, + result: result as usize, + }); + + GPIO_MUL.as_ptr().write_volatile(output_ptr); +} diff --git a/risc0/zkvm/sdk/rust/platform/src/io.rs b/risc0/zkvm/sdk/rust/platform/src/io.rs index b072cb43a1..34a950351f 100644 --- a/risc0/zkvm/sdk/rust/platform/src/io.rs +++ b/risc0/zkvm/sdk/rust/platform/src/io.rs @@ -49,6 +49,8 @@ pub const GPIO_SENDRECV_CHANNEL: Gpio = Gpio::new(0x01F0_0014); pub const GPIO_SENDRECV_SIZE: Gpio = Gpio::new(0x01F0_0018); pub const GPIO_SENDRECV_ADDR: Gpio<*const u8> = Gpio::new(0x01F0_001C); +pub const GPIO_MUL: Gpio<*const MulDescriptor> = Gpio::new(0x01F0_0020); + pub mod addr { pub const GPIO_SHA: u32 = super::GPIO_SHA.addr(); pub const GPIO_COMMIT: u32 = super::GPIO_COMMIT.addr(); @@ -59,6 +61,8 @@ pub mod addr { pub const GPIO_SENDRECV_CHANNEL: u32 = super::GPIO_SENDRECV_CHANNEL.addr(); pub const GPIO_SENDRECV_SIZE: u32 = super::GPIO_SENDRECV_SIZE.addr(); pub const GPIO_SENDRECV_ADDR: u32 = super::GPIO_SENDRECV_ADDR.addr(); + + pub const GPIO_MUL: u32 = super::GPIO_MUL.addr(); } #[repr(C)] @@ -75,6 +79,12 @@ pub struct SHADescriptor { pub digest: usize, } +#[repr(C)] +pub struct MulDescriptor { + pub source: usize, + pub result: usize, +} + #[repr(C)] pub struct GetKeyDescriptor { pub name: u32, @@ -86,3 +96,4 @@ pub struct GetKeyDescriptor { pub const SENDRECV_CHANNEL_INITIAL_INPUT: u32 = 0; pub const SENDRECV_CHANNEL_STDOUT: u32 = 1; pub const SENDRECV_CHANNEL_STDERR: u32 = 2; +pub const SENDRECV_CHANNEL_INITIAL_AUX_INPUT: u32 = 3; diff --git a/risc0/zkvm/sdk/rust/platform/src/memory.rs b/risc0/zkvm/sdk/rust/platform/src/memory.rs index e9034d9d4a..5b2fea9209 100644 --- a/risc0/zkvm/sdk/rust/platform/src/memory.rs +++ b/risc0/zkvm/sdk/rust/platform/src/memory.rs @@ -62,6 +62,7 @@ pub const INPUT: Region = Region::new(0x01E0_0000, mb(1)); pub const GPIO: Region = Region::new(0x01F0_0000, mb(1)); pub const PROG: Region = Region::new(0x0200_0000, mb(10)); pub const SHA: Region = Region::new(0x02A0_0000, mb(1)); -pub const WOM: Region = Region::new(0x02B0_0000, mb(21)); -pub const OUTPUT: Region = Region::new(0x02B0_0000, mb(20)); +pub const MUL: Region = Region::new(0x02B0_0000, mb(1)); +pub const WOM: Region = Region::new(0x02C0_0000, mb(20)); +pub const OUTPUT: Region = Region::new(0x02C0_0000, mb(19)); pub const COMMIT: Region = Region::new(0x03F0_0000, mb(1)); diff --git a/risc0/zkvm/sdk/rust/src/host/ffi.rs b/risc0/zkvm/sdk/rust/src/host/ffi.rs index f09f8131c7..7d209ef6d3 100644 --- a/risc0/zkvm/sdk/rust/src/host/ffi.rs +++ b/risc0/zkvm/sdk/rust/src/host/ffi.rs @@ -102,6 +102,13 @@ extern "C" { len: usize, ); + pub(crate) fn risc0_prover_add_aux_input( + err: *mut RawError, + prover: *mut RawProver, + buf: *const u8, + len: usize, + ); + pub(crate) fn risc0_prover_get_output_buf( err: *mut RawError, prover: *mut RawProver, @@ -388,6 +395,29 @@ impl<'a> Prover<'a> { check(err, || ()) } + /// Provide private input data that is availble to guest-side method code + /// to 'read_aux_input'. + pub fn add_aux_input(&mut self, slice: &[u32]) -> super::Result<()> { + let mut err = RawError::default(); + unsafe { + risc0_prover_add_aux_input( + &mut err, + self.ptr, + slice.as_ptr().cast(), + slice.len() * mem::size_of::(), + ) + }; + check(err, || ()) + } + + /// Allow auxiliary input to be passed in as u8 with zero-copy framework + pub fn add_input_u8_slice_aux(&mut self, slice: &[u8]) { + let mut v: Vec = Vec::new(); + v.resize((slice.len() + 3) / 4, 0); + bytemuck::cast_slice_mut(v.as_mut_slice())[..slice.len()].clone_from_slice(slice); + self.add_aux_input(v.as_slice()).unwrap() + } + /// Compatibility with pure-rust prover pub fn add_input_u8_slice(&mut self, slice: &[u8]) { let mut v: Vec = Vec::new();