Skip to content

Commit

Permalink
CairoType refactor & derive proc-macro
Browse files Browse the repository at this point in the history
  • Loading branch information
Okm165 committed Feb 5, 2025
1 parent b343f7a commit b6cf1bd
Show file tree
Hide file tree
Showing 45 changed files with 1,503 additions and 9,344 deletions.
123 changes: 61 additions & 62 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
resolver = "2"

members = [
"crates/cairo_type_derive",
"crates/dry_hint_processor",
"crates/dry_run",
"crates/fetcher",
Expand All @@ -23,7 +24,6 @@ axum = { version = "0.8", features = ["tracing"] }
bincode = { version = "2.0.0-rc.3", default-features = false, features = ["serde"]}
cairo-lang-casm = { version = "2.10.0-rc.1", default-features = false }
cairo-lang-starknet-classes = "2.10.0-rc.1"
cairo-type-derive = { git = "https://github.com/keep-starknet-strange/snos.git", rev = "35a300a10d2107482ada440b5025ee2f651afbd4" }
cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm", rev = "62804bcbf58b436a8986e7da0ee80333400b1ffb", features = ["extensive_hints", "clap", "cairo-1-hints"] }
clap = { version = "4.3.10", features = ["derive"] }
eth-trie-proofs = { git = "https://github.com/HerodotusDev/eth-trie-proofs.git" }
Expand Down Expand Up @@ -57,6 +57,7 @@ utoipa = { version = "5.3.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "9", features = ["axum"] }
version-compare = "=0.0.11"

cairo_type_derive = { path = "crates/cairo_type_derive" }
dry_hint_processor = { path = "crates/dry_hint_processor" }
eth_essentials_cairo_vm_hints = { path = "packages/eth_essentials/cairo_vm_hints" }
fetcher = { path = "crates/fetcher" }
Expand All @@ -65,4 +66,4 @@ indexer = { path = "crates/indexer" }
pathfinder_gateway_types = { git = "https://github.com/eqlabs/pathfinder", package = "starknet-gateway-types" }
sound_hint_processor = { path = "crates/sound_hint_processor" }
syscall_handler = { path = "crates/syscall_handler" }
types = { path = "crates/types" }
types = { path = "crates/types" }
12 changes: 12 additions & 0 deletions crates/cairo_type_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "cairo_type_derive"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
quote = "1.0.35"
syn = "2.0.48"
proc-macro2 = "1.0.78"
153 changes: 153 additions & 0 deletions crates/cairo_type_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
extern crate proc_macro;

use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};

#[proc_macro_derive(CairoType)]
pub fn cairo_type_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
// Parse the input tokens into a syntax tree
let input = parse_macro_input!(input as DeriveInput);

// Get the identifier of the struct
let struct_ident = &input.ident;

// Generate code to implement the trait
let expanded = match &input.data {
Data::Struct(data_struct) => match &data_struct.fields {
Fields::Named(fields) => {
let field_names_read = fields.named.iter().map(|f| &f.ident);
let field_names_write = field_names_read.clone();
let n_fields = field_names_read.clone().count();
let field_values = field_names_read.clone().enumerate().map(|(index, field_name)| {
quote! {
let #field_name = vm.get_integer((address + #index)?)?.into_owned();
}
});

quote! {
impl CairoType for #struct_ident {
fn from_memory(vm: &VirtualMachine, address: Relocatable) -> Result<Self, MemoryError> {
#(#field_values)*
Ok(Self {
#( #field_names_read ),*
})
}
fn to_memory(&self, vm: &mut VirtualMachine, address: Relocatable) -> Result<Relocatable, MemoryError> {
let mut offset = 0;
#(vm.insert_value((address + offset)?, &self.#field_names_write)?; offset += 1;)*

Ok((address + offset)?)
}

fn n_fields(vm: &VirtualMachine, address: Relocatable) -> Result<usize, MemoryError> {
Ok(#n_fields)
}
}
}
}
Fields::Unnamed(_) | Fields::Unit => {
// Unsupported field types
quote! {
compile_error!("CairoType only supports structs with named fields");
}
}
},
Data::Enum(_) | Data::Union(_) => {
// Unsupported data types
quote! {
compile_error!("CairoType only supports structs");
}
}
};

// Convert the generated code into a TokenStream and return it
proc_macro::TokenStream::from(expanded)
}

fn field_size(field: &syn::Field) -> proc_macro2::TokenStream {
if let Type::Path(type_path) = &field.ty {
if let Some(segment) = type_path.path.segments.last() {
let type_name = &segment.ident;
if type_name == "Felt252" || type_name == "Relocatable" {
quote! { 1 }
} else {
quote! { #type_name::cairo_size() }
}
} else {
let field_name = field.ident.as_ref().unwrap();
quote! {
compile_error!("Could not determine the size of {}.", #field_name);
}
}
} else {
quote! {
compile_error!("Could not determine the size of all fields in the struct. This derive macro is only compatible with Felt252 fields.");
}
}
}

/// Provides a method to compute the address of each field
#[proc_macro_derive(FieldOffsetGetters)]
pub fn get_field_offsets_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
// Parse the input tokens into a syntax tree
let input = parse_macro_input!(input as DeriveInput);

// Get the identifier of the struct
let struct_ident = &input.ident;

// Generate code to implement the trait
let getters = match &input.data {
Data::Struct(data_struct) => {
// Extract fields' names and types
let fields = match &data_struct.fields {
Fields::Named(fields) => &fields.named,
_ => {
return quote! {
compile_error!("FieldOffsetGetters only supports structs with named fields");
}
.into();
}
};

let mut field_sizes: Vec<proc_macro2::TokenStream> = vec![];
let mut get_field_offset_methods: Vec<proc_macro2::TokenStream> = vec![];

for field in fields {
let field_name = field.ident.as_ref().expect("Expected named field");
let offset_fn_name = format_ident!("{}_offset", field_name);

let rendered_offset_impl = if field_sizes.is_empty() {
quote! { 0 }
} else {
quote! { #(#field_sizes)+* }
};

let get_offset_method = quote! {
pub fn #offset_fn_name() -> usize {
#rendered_offset_impl
}
};
get_field_offset_methods.push(get_offset_method);
field_sizes.push(field_size(field));
}

// Combine all setter methods
quote! {
impl #struct_ident {
#( #get_field_offset_methods )*
pub fn cairo_size() -> usize {
#(#field_sizes)+*
}
}
}
}
_ => {
quote! {
compile_error!("FieldOffsetGetters only supports structs");
}
}
};

// Convert the generated code into a TokenStream and return it
getters.into()
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl CallHandler for AccountCallHandler {

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
ret.try_into()
.map_err(|e| SyscallExecutionError::InternalError(format!("{}", e).into()))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl CallHandler for HeaderCallHandler {

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
ret.try_into()
.map_err(|e| SyscallExecutionError::InternalError(format!("{}", e).into()))
}
Expand Down
10 changes: 5 additions & 5 deletions crates/dry_hint_processor/src/syscall_handler/evm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,39 +70,39 @@ impl SyscallHandler for CallContractHandler {
let result = header::HeaderCallHandler.handle(key.clone(), function_id, vm).await?;
self.key_set.insert(DryRunKey::Header(key));
result.to_memory(vm, retdata_end)?;
retdata_end += <header::HeaderCallHandler as CallHandler>::CallHandlerResult::n_fields();
retdata_end += <header::HeaderCallHandler as CallHandler>::CallHandlerResult::n_fields(vm, retdata_end)?;
}
CallHandlerId::Account => {
let key = account::AccountCallHandler::derive_key(vm, &mut calldata)?;
let function_id = account::AccountCallHandler::derive_id(request.selector)?;
let result = account::AccountCallHandler.handle(key.clone(), function_id, vm).await?;
self.key_set.insert(DryRunKey::Account(key));
result.to_memory(vm, retdata_end)?;
retdata_end += <account::AccountCallHandler as CallHandler>::CallHandlerResult::n_fields();
retdata_end += <account::AccountCallHandler as CallHandler>::CallHandlerResult::n_fields(vm, retdata_end)?;
}
CallHandlerId::Storage => {
let key = storage::StorageCallHandler::derive_key(vm, &mut calldata)?;
let function_id = storage::StorageCallHandler::derive_id(request.selector)?;
let result = storage::StorageCallHandler.handle(key.clone(), function_id, vm).await?;
self.key_set.insert(DryRunKey::Storage(key));
result.to_memory(vm, retdata_end)?;
retdata_end += <storage::StorageCallHandler as CallHandler>::CallHandlerResult::n_fields();
retdata_end += <storage::StorageCallHandler as CallHandler>::CallHandlerResult::n_fields(vm, retdata_end)?;
}
CallHandlerId::Transaction => {
let key = transaction::TransactionCallHandler::derive_key(vm, &mut calldata)?;
let function_id = transaction::TransactionCallHandler::derive_id(request.selector)?;
let result = transaction::TransactionCallHandler.handle(key.clone(), function_id, vm).await?;
self.key_set.insert(DryRunKey::Tx(key));
result.to_memory(vm, retdata_end)?;
retdata_end += <transaction::TransactionCallHandler as CallHandler>::CallHandlerResult::n_fields();
retdata_end += <transaction::TransactionCallHandler as CallHandler>::CallHandlerResult::n_fields(vm, retdata_end)?;
}
CallHandlerId::Receipt => {
let key = receipt::ReceiptCallHandler::derive_key(vm, &mut calldata)?;
let function_id = receipt::ReceiptCallHandler::derive_id(request.selector)?;
let result = receipt::ReceiptCallHandler.handle(key.clone(), function_id, vm).await?;
self.key_set.insert(DryRunKey::Receipt(key));
result.to_memory(vm, retdata_end)?;
retdata_end += <receipt::ReceiptCallHandler as CallHandler>::CallHandlerResult::n_fields();
retdata_end += <receipt::ReceiptCallHandler as CallHandler>::CallHandlerResult::n_fields(vm, retdata_end)?;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl CallHandler for ReceiptCallHandler {

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
ret.try_into()
.map_err(|e| SyscallExecutionError::InternalError(format!("{}", e).into()))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl CallHandler for StorageCallHandler {

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
ret.try_into()
.map_err(|e| SyscallExecutionError::InternalError(format!("{}", e).into()))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl CallHandler for TransactionCallHandler {

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
ret.try_into()
.map_err(|e| SyscallExecutionError::InternalError(format!("{}", e).into()))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use syscall_handler::{traits::CallHandler, SyscallExecutionError, SyscallResult}
use types::{
cairo::{
starknet::header::{Block, FunctionId, StarknetBlock},
structs::Felt,
structs::CairoFelt,
traits::CairoType,
},
keys::starknet::header::{CairoKey, Key},
Expand All @@ -20,11 +20,11 @@ pub struct HeaderCallHandler;
impl CallHandler for HeaderCallHandler {
type Key = Key;
type Id = FunctionId;
type CallHandlerResult = Felt;
type CallHandlerResult = CairoFelt;

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
ret.try_into()
.map_err(|e| SyscallExecutionError::InternalError(format!("{}", e).into()))
}
Expand Down
4 changes: 2 additions & 2 deletions crates/dry_hint_processor/src/syscall_handler/starknet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ impl SyscallHandler for CallContractHandler {
let result = HeaderCallHandler.handle(key.clone(), function_id, vm).await?;
self.key_set.insert(DryRunKey::Header(key));
result.to_memory(vm, retdata_end)?;
retdata_end += <HeaderCallHandler as CallHandler>::CallHandlerResult::n_fields();
retdata_end += <HeaderCallHandler as CallHandler>::CallHandlerResult::n_fields(vm, retdata_end)?;
}
CallHandlerId::Storage => {
let key = StorageCallHandler::derive_key(vm, &mut calldata)?;
let function_id = StorageCallHandler::derive_id(request.selector)?;
let result = StorageCallHandler.handle(key.clone(), function_id, vm).await?;
self.key_set.insert(DryRunKey::Storage(key));
result.to_memory(vm, retdata_end)?;
retdata_end += <StorageCallHandler as CallHandler>::CallHandlerResult::n_fields();
retdata_end += <StorageCallHandler as CallHandler>::CallHandlerResult::n_fields(vm, retdata_end)?;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use starknet::{
};
use syscall_handler::{traits::CallHandler, SyscallExecutionError, SyscallResult};
use types::{
cairo::{evm::storage::FunctionId, structs::Felt, traits::CairoType},
cairo::{evm::storage::FunctionId, structs::CairoFelt, traits::CairoType},
keys::starknet::storage::{CairoKey, Key},
RPC_URL_STARKNET,
};
Expand All @@ -22,11 +22,11 @@ pub struct StorageCallHandler;
impl CallHandler for StorageCallHandler {
type Key = Key;
type Id = FunctionId;
type CallHandlerResult = Felt;
type CallHandlerResult = CairoFelt;

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
ret.try_into()
.map_err(|e| SyscallExecutionError::InternalError(format!("{}", e).into()))
}
Expand All @@ -49,7 +49,7 @@ impl CallHandler for StorageCallHandler {
FunctionId::Storage => provider
.get_storage_at::<Felt252, Felt252, BlockId>(key.address, key.storage_slot, block_id)
.await
.map(Felt::from),
.map(CairoFelt::from),
}
.map_err(|e| SyscallExecutionError::InternalError(e.to_string().into()))?;

Expand Down
2 changes: 1 addition & 1 deletion crates/hints/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ edition = "2021"
alloy-rlp.workspace = true
alloy.workspace = true
bincode.workspace = true
cairo_type_derive.workspace = true
cairo-lang-casm.workspace = true
cairo-lang-starknet-classes.workspace = true
cairo-type-derive.workspace = true
cairo-vm.workspace = true
eth_essentials_cairo_vm_hints.workspace = true
hex.workspace = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl CallHandler for AccountCallHandler {

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
Ok(ret)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl CallHandler for HeaderCallHandler {

fn derive_key(vm: &VirtualMachine, ptr: &mut Relocatable) -> SyscallResult<Self::Key> {
let ret = CairoKey::from_memory(vm, *ptr)?;
*ptr = (*ptr + CairoKey::n_fields())?;
*ptr = (*ptr + CairoKey::n_fields(vm, *ptr)?)?;
Ok(ret)
}

Expand Down
Loading

0 comments on commit b6cf1bd

Please sign in to comment.