Skip to content

Iterate on verifier 2 #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,31 @@
"ios": "cpp",
"iosfwd": "cpp",
"vector": "cpp",
"charconv": "cpp"
"charconv": "cpp",
"__hash_table": "cpp",
"__split_buffer": "cpp",
"__tree": "cpp",
"array": "cpp",
"bitset": "cpp",
"deque": "cpp",
"initializer_list": "cpp",
"list": "cpp",
"map": "cpp",
"queue": "cpp",
"random": "cpp",
"regex": "cpp",
"set": "cpp",
"span": "cpp",
"stack": "cpp",
"string": "cpp",
"string_view": "cpp",
"unordered_map": "cpp",
"unordered_set": "cpp",
"valarray": "cpp",
"iterator": "cpp",
"utility": "cpp",
"rope": "cpp",
"slist": "cpp",
"ranges": "cpp"
}
}
9 changes: 9 additions & 0 deletions risc0/zkp/rust/src/core/sha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ impl Digest {
&self.0
}

/// Returns as a slice of be u8
pub fn get_u8(&self) -> [u8; DIGEST_WORDS * 4] {
let mut res: [u8; DIGEST_WORDS * 4] = [0; DIGEST_WORDS * 4];
for i in 0..DIGEST_WORDS {
res[4 * i..][..4].copy_from_slice(&self.0[i].to_be_bytes());
}
res
}

/// Returns a mutable slice of words.
pub fn get_mut(&mut self) -> &mut [u32; DIGEST_WORDS] {
&mut self.0
Expand Down
12 changes: 12 additions & 0 deletions risc0/zkvm/platform/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ constexpr size_t kGPIO_GetKey = 0x01F0010;
constexpr size_t kGPIO_SendRecvChannel = 0x01F00014;
constexpr size_t kGPIO_SendRecvSize = 0x01F00018;
constexpr size_t kGPIO_SendRecvAddr = 0x01F0001C;
constexpr size_t kGPIO_Mul = 0x01F00020;

// Standard ZKVM channels; must match zkvm/sdk/rust/platform/src/io.rs.

Expand All @@ -35,6 +36,8 @@ constexpr uint32_t kSendRecvChannel_InitialInput = 0;
constexpr uint32_t kSendRecvChannel_Stdout = 1;
// Write bytes to standard error
constexpr uint32_t kSendRecvChannel_Stderr = 2;
// Request aux tape to the guest
constexpr uint32_t kSendRecvChannel_InitialInputAux = 3;

// To invoke accelerated SHA, the guest writes ShaDescriptor structs
// in sequence to the "SHA" memory region. Once the ShaDescriptor has
Expand Down Expand Up @@ -65,6 +68,15 @@ struct ShaDescriptor {
uint32_t digest;
};

struct MulDescriptor {
// Address of first byte of MUL data to process
// 128 bits for first operand and 128 bits for second
uint32_t source;

// 128 bit result
uint32_t result;
};

inline volatile ShaDescriptor* volatile* GPIO_SHA() {
return reinterpret_cast<volatile ShaDescriptor* volatile*>(kGPIO_SHA);
}
Expand Down
5 changes: 3 additions & 2 deletions risc0/zkvm/platform/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ MEM_REGION(Input, 0x01E00000, k1MB)
MEM_REGION(GPIO, 0x01F00000, k1MB)
MEM_REGION(Prog, 0x02000000, 10 * k1MB)
MEM_REGION(SHA, 0x02A00000, k1MB)
MEM_REGION(WOM, 0x02B00000, 21 * k1MB)
MEM_REGION(Output, 0x02B00000, 20 * k1MB)
MEM_REGION(MUL, 0x02B00000, k1MB)
MEM_REGION(WOM, 0x02C00000, 20 * k1MB)
MEM_REGION(Output, 0x02C00000, 19 * k1MB)
MEM_REGION(Commit, 0x03F00000, k1MB)
// clang-format on

Expand Down
3 changes: 2 additions & 1 deletion risc0/zkvm/platform/risc0.ld
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ MEMORY {
gpio : ORIGIN = 0x01F00000, LENGTH = 1M
prog (X) : ORIGIN = 0x02000000, LENGTH = 10M
sha : ORIGIN = 0x02A00000, LENGTH = 1M
wom : ORIGIN = 0x02B00000, LENGTH = 21M
mul : ORIGIN = 0x02B00000, LENGTH = 1M
wom : ORIGIN = 0x02C00000, LENGTH = 20M
}

SECTIONS {
Expand Down
129 changes: 129 additions & 0 deletions risc0/zkvm/prove/io_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,82 @@

namespace risc0 {

class FpG {
// implement just enough operations to support extension field multiplication
// all values are in mont form
public:
static CONSTSCALAR uint64_t M = 0xFFFFFFFF00000001;
uint64_t val;

private:
static DEVSPEC constexpr uint64_t add(uint64_t a, uint64_t b) {
bool c1 = (M - b) > a;
uint64_t x1 = a - (M - b);
uint32_t adj = uint32_t(0) - uint32_t(c1);
uint64_t res = x1 - uint64_t(adj);
// std::cout << "c1: " << c1 << ", x1: " << x1 << ", adj: " << adj << ", res: " << res <<
// std::endl;
return res;
}

static DEVSPEC constexpr uint64_t sub(uint64_t a, uint64_t b) {
bool c1 = b > a;
uint64_t x1 = a - b;
uint32_t adj = 0 - uint32_t(c1);
return x1 - uint64_t(adj);
}

static DEVSPEC constexpr uint64_t doubleVal(uint64_t a) {
__uint128_t ret = __uint128_t(a) << 1;
uint64_t result = uint64_t(ret);
uint64_t over = uint64_t(ret >> 64);
return result - (M * over);
}

static DEVSPEC constexpr uint64_t montRedCst(__uint128_t n) {
uint64_t xl = uint64_t(n);
uint64_t xh = uint64_t(n >> 64);
bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; // overflow
uint64_t a = xl + (xl << 32);
uint64_t b = a - (a >> 32) - e;
bool c = xh < b;
uint64_t r = xh - b;
uint64_t mont_result = r - (uint32_t(0) - uint32_t(c));
return mont_result;
}

static DEVSPEC constexpr uint64_t mul(uint64_t a, uint64_t b) {
__uint128_t n = __uint128_t(a) * __uint128_t(b);
return montRedCst(n);
}

public:
DEVSPEC constexpr FpG(uint64_t val) : val(val) {}
DEVSPEC constexpr FpG operator+(FpG rhs) const { return FpG(add(val, rhs.val)); }
DEVSPEC constexpr FpG operator-(FpG rhs) const { return FpG(sub(val, rhs.val)); }
DEVSPEC constexpr FpG operator*(FpG rhs) const { return FpG(mul(val, rhs.val)); }
DEVSPEC constexpr FpG doubleVal() const { return FpG(doubleVal(val)); }
};

static std::pair<FpG, FpG> extensionMul(std::pair<FpG, FpG> a, std::pair<FpG, FpG> b) {
FpG a0b0 = a.first * b.first;
FpG a1b1 = a.second * b.second;
FpG first = a0b0 - a1b1.doubleVal();

FpG a0a1 = a.first + a.second;
FpG b0b1 = b.first + b.second;
FpG second = a0a1 * b0b1 - a0b0;

// std::cout << "CPP a: [" << a.first.val << ", " << a.second.val << "]" << std::endl;
// std::cout << "b: [" << b.first.val << ", " << b.second.val << "]" << std::endl;

// std::cout << "a0b0: " << a0b0.val << ", a1b1: " << a1b1.val << ", first: " << first.val
// << ", a0a1: " << a0a1.val << ", b0b1: " << b0b1.val << ", second: "
// << second.val << std::endl;

return std::pair(first, second);
}

static void processSHA(MemoryState& mem, const ShaDescriptor& desc) {
uint16_t type = (desc.typeAndCount & 0xFFFF) >> 4;
uint16_t count = desc.typeAndCount & 0xFFFF;
Expand All @@ -46,6 +122,52 @@ static void processSHA(MemoryState& mem, const ShaDescriptor& desc) {
}
}

static void processMul(MemoryState& mem, const MulDescriptor& desc) {
uint32_t a0_hi = mem.load(desc.source);
LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a0_hi));
uint32_t a0_lo = mem.load(desc.source + 4);
LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a0_lo));
uint32_t a1_hi = mem.load(desc.source + 8);
LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(a1_hi));
uint32_t a1_lo = mem.load(desc.source + 12);
LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(a1_lo));

uint32_t b0_hi = mem.load(desc.source + 16);
LOG(1, "Input[" << hex(4, 2) << "]: " << hex(desc.source + 16) << " -> " << hex(b0_hi));
uint32_t b0_lo = mem.load(desc.source + 20);
LOG(1, "Input[" << hex(5, 2) << "]: " << hex(desc.source + 20) << " -> " << hex(b0_lo));
uint32_t b1_hi = mem.load(desc.source + 24);
LOG(1, "Input[" << hex(6, 2) << "]: " << hex(desc.source + 24) << " -> " << hex(b1_hi));
uint32_t b1_lo = mem.load(desc.source + 28);
LOG(1, "Input[" << hex(7, 2) << "]: " << hex(desc.source + 28) << " -> " << hex(b1_lo));

uint64_t a0 = a0_lo | (uint64_t(a0_hi) << 32);
uint64_t a1 = a1_lo | (uint64_t(a1_hi) << 32);
uint64_t b0 = b0_lo | (uint64_t(b0_hi) << 32);
uint64_t b1 = b1_lo | (uint64_t(b1_hi) << 32);

std::pair<FpG, FpG> a = std::pair(FpG(a0), FpG(a1));
std::pair<FpG, FpG> b = std::pair(FpG(b0), FpG(b1));
std::pair<FpG, FpG> result = extensionMul(a, b);

uint64_t r0 = result.first.val;
uint32_t r0_high = (uint32_t)((r0 & 0xFFFFFFFF00000000LL) >> 32);
uint32_t r0_low = (uint32_t)(r0 & 0xFFFFFFFFLL);

uint64_t r1 = result.second.val;
uint32_t r1_high = (uint32_t)((r1 & 0xFFFFFFFF00000000LL) >> 32);
uint32_t r1_low = (uint32_t)(r1 & 0xFFFFFFFFLL);

LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(r0_high));
mem.store(desc.result, r0_high);
LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 4) << " <- " << hex(r0_low));
mem.store(desc.result + 4, r0_low);
LOG(1, "Output[" << hex(2, 2) << "]: " << hex(desc.result + 8) << " <- " << hex(r1_high));
mem.store(desc.result + 8, r1_high);
LOG(1, "Output[" << hex(3, 2) << "]: " << hex(desc.result + 12) << " <- " << hex(r1_low));
mem.store(desc.result + 12, r1_low);
}

void IoHandler::onFault(const std::string& msg) {
throw std::runtime_error(msg);
}
Expand All @@ -63,6 +185,13 @@ void MemoryHandler::onInit(MemoryState& mem) {
void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value) {
LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value));
switch (addr) {
case kGPIO_Mul: {
LOG(1, "MemoryHandler::onWrite> GPIO_MUL");
MulDescriptor desc;
mem.loadRegion(value, &desc, sizeof(desc));
processMul(mem, desc);
break;
}
case kGPIO_SHA: {
LOG(1, "MemoryHandler::onWrite> GPIO_SHA");
ShaDescriptor desc;
Expand Down
7 changes: 7 additions & 0 deletions risc0/zkvm/sdk/cpp/host/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ void risc0_prover_add_input(risc0_error* err, risc0_prover* ptr, const uint8_t*
ffi_wrap_void(err, [&] { ptr->prover->writeInput(buf, len); });
}

void risc0_prover_add_aux_input(risc0_error* err,
risc0_prover* ptr,
const uint8_t* buf,
size_t len) {
ffi_wrap_void(err, [&] { ptr->prover->writeInputAux(buf, len); });
}

const void* risc0_prover_get_output_buf(risc0_error* err, const risc0_prover* ptr) {
return ffi_wrap<const void*>(err, nullptr, [&] { return ptr->prover->getOutput().data(); });
}
Expand Down
5 changes: 5 additions & 0 deletions risc0/zkvm/sdk/cpp/host/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ void risc0_prover_free(risc0_error* err, risc0_prover* ptr);

void risc0_prover_add_input(risc0_error* err, risc0_prover* ptr, const uint8_t* buf, size_t len);

void risc0_prover_add_aux_input(risc0_error* err,
risc0_prover* ptr,
const uint8_t* buf,
size_t len);

size_t risc0_prover_get_num_outputs(risc0_error* err, risc0_prover* ptr);

const void* risc0_prover_get_output_buf(risc0_error* err, const risc0_prover* ptr);
Expand Down
35 changes: 35 additions & 0 deletions risc0/zkvm/sdk/cpp/host/receipt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ struct Prover::Impl : public IoHandler {
, outputStream(outputBuffer)
, commitStream(commitBuffer)
, inputWriter(inputStream)
, inputWriterAux(inputStreamAux)
, outputReader(outputStream)
, commitReader(commitStream) {
// Set default handlers:
Expand All @@ -83,6 +84,13 @@ struct Prover::Impl : public IoHandler {
LOG(1, "IoHandler::InitialInput, " << input.size() << " bytes");
return input;
});
setSendRecvHandler(
kSendRecvChannel_InitialInputAux, [this](uint32_t, const BufferU8& buf) -> BufferU8 {
const uint8_t* byte_ptr = reinterpret_cast<const uint8_t*>(inputStreamAux.vec.data());
BufferU8 input(byte_ptr, byte_ptr + inputStreamAux.vec.size() * sizeof(uint32_t));
LOG(1, "IoHandler::InitialInputAux, " << input.size() << " bytes");
return input;
});
}

virtual ~Impl() {}
Expand Down Expand Up @@ -116,9 +124,11 @@ struct Prover::Impl : public IoHandler {
BufferU8 outputBuffer;
BufferU8 commitBuffer;
VectorStreamWriter inputStream;
VectorStreamWriter inputStreamAux;
CheckedStreamReader outputStream;
CheckedStreamReader commitStream;
ArchiveWriter<VectorStreamWriter> inputWriter;
ArchiveWriter<VectorStreamWriter> inputWriterAux;
ArchiveReader<CheckedStreamReader> outputReader;
ArchiveReader<CheckedStreamReader> commitReader;

Expand Down Expand Up @@ -219,6 +229,31 @@ void Prover::writeInput(const void* ptr, size_t size) {
}
}

void Prover::writeInputAux(const void* ptr, size_t size) {
LOG(1, "Prover::writeInputAux> size: " << size);
const uint8_t* ptr_u8 = static_cast<const uint8_t*>(ptr);
while (size >= sizeof(uint32_t)) {
uint32_t word = 0;
word |= *ptr_u8++;
word |= *ptr_u8++ << 8;
word |= *ptr_u8++ << 16;
word |= *ptr_u8++ << 24;
LOG(1, " write_word: " << hex(word));
impl->inputStreamAux.write_word(word);
size -= sizeof(uint32_t);
}

if (size) {
LOG(1, " tail: " << size);
uint32_t word = 0;
for (size_t i = 0; i < size; i++) {
word |= *ptr_u8++ << (8 * i);
}
LOG(1, " write_word: " << hex(word));
impl->inputStreamAux.write_word(word);
}
}

void Prover::setSendRecvHandler(
uint32_t channelId,
const std::function<BufferU8(uint32_t /* channelId*/, const BufferU8&)>& handler) {
Expand Down
2 changes: 2 additions & 0 deletions risc0/zkvm/sdk/cpp/host/receipt.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class Prover {

void writeInput(const void* ptr, size_t size);

void writeInputAux(const void* ptr, size_t size);

template <typename T> void writeInput(const T& obj) { getInputWriter().transfer(obj); }

const BufferU8& getOutput();
Expand Down
Loading