Skip to content

Commit ddb3773

Browse files
authored
Parallel comemo & optimizations (#5)
1 parent 2b3b8ee commit ddb3773

File tree

12 files changed

+648
-372
lines changed

12 files changed

+648
-372
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ jobs:
77
steps:
88
- uses: actions/checkout@v3
99
- uses: dtolnay/rust-toolchain@stable
10-
- run: cargo build
11-
- run: cargo test
10+
- run: cargo build --all-features
11+
- run: cargo test --all-features

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.vscode
22
.DS_Store
33
/target
4+
macros/target
45
Cargo.lock

Cargo.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@ license = "MIT OR Apache-2.0"
1010
categories = ["caching"]
1111
keywords = ["incremental", "memoization", "tracked", "constraints"]
1212

13+
[features]
14+
default = []
15+
testing = []
16+
1317
[dependencies]
1418
comemo-macros = { version = "0.3.1", path = "macros" }
19+
once_cell = "1.18"
20+
parking_lot = "0.12"
1521
siphasher = "1"
22+
23+
[dev-dependencies]
24+
serial_test = "2.0.0"
25+
26+
[[test]]
27+
name = "tests"
28+
path = "tests/tests.rs"
29+
required-features = ["testing"]

macros/src/memoize.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub fn expand(item: &syn::Item) -> Result<proc_macro2::TokenStream> {
77
};
88

99
// Preprocess and validate the function.
10-
let function = prepare(&item)?;
10+
let function = prepare(item)?;
1111

1212
// Rewrite the function's body to memoize it.
1313
process(&function)
@@ -23,7 +23,7 @@ struct Function {
2323
/// An argument to a memoized function.
2424
enum Argument {
2525
Receiver(syn::Token![self]),
26-
Ident(Option<syn::Token![mut]>, syn::Ident),
26+
Ident(Box<syn::Type>, Option<syn::Token![mut]>, syn::Ident),
2727
}
2828

2929
/// Preprocess and validate a function.
@@ -71,7 +71,7 @@ fn prepare_arg(input: &syn::FnArg) -> Result<Argument> {
7171
bail!(typed.ty, "memoized functions cannot have mutable parameters")
7272
}
7373

74-
Argument::Ident(mutability.clone(), ident.clone())
74+
Argument::Ident(typed.ty.clone(), *mutability, ident.clone())
7575
}
7676
})
7777
}
@@ -82,7 +82,7 @@ fn process(function: &Function) -> Result<TokenStream> {
8282
let bounds = function.args.iter().map(|arg| {
8383
let val = match arg {
8484
Argument::Receiver(token) => quote! { #token },
85-
Argument::Ident(_, ident) => quote! { #ident },
85+
Argument::Ident(_, _, ident) => quote! { #ident },
8686
};
8787
quote_spanned! { function.item.span() =>
8888
::comemo::internal::assert_hashable_or_trackable(&#val);
@@ -94,14 +94,20 @@ fn process(function: &Function) -> Result<TokenStream> {
9494
Argument::Receiver(token) => quote! {
9595
::comemo::internal::hash(&#token)
9696
},
97-
Argument::Ident(_, ident) => quote! { #ident },
97+
Argument::Ident(_, _, ident) => quote! { #ident },
9898
});
9999
let arg_tuple = quote! { (#(#args,)*) };
100100

101+
let arg_tys = function.args.iter().map(|arg| match arg {
102+
Argument::Receiver(_) => quote! { () },
103+
Argument::Ident(ty, _, _) => quote! { #ty },
104+
});
105+
let arg_ty_tuple = quote! { (#(#arg_tys,)*) };
106+
101107
// Construct a tuple for all parameters.
102108
let params = function.args.iter().map(|arg| match arg {
103109
Argument::Receiver(_) => quote! { _ },
104-
Argument::Ident(mutability, ident) => quote! { #mutability #ident },
110+
Argument::Ident(_, mutability, ident) => quote! { #mutability #ident },
105111
});
106112
let param_tuple = quote! { (#(#params,)*) };
107113

@@ -118,14 +124,20 @@ fn process(function: &Function) -> Result<TokenStream> {
118124
ident.mutability = None;
119125
}
120126

121-
let unique = quote! { __ComemoUnique };
122127
wrapped.block = parse_quote! { {
123-
struct #unique;
128+
static __CACHE: ::comemo::internal::Cache<
129+
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
130+
#output,
131+
> = ::comemo::internal::Cache::new(|| {
132+
::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age));
133+
::core::default::Default::default()
134+
});
135+
124136
#(#bounds;)*
125137
::comemo::internal::memoized(
126-
::core::any::TypeId::of::<#unique>(),
127138
::comemo::internal::Args(#arg_tuple),
128139
&::core::default::Default::default(),
140+
&__CACHE,
129141
#closure,
130142
)
131143
} };

macros/src/track.rs

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,34 +20,38 @@ pub fn expand(item: &syn::Item) -> Result<TokenStream> {
2020
}
2121

2222
for item in &item.items {
23-
methods.push(prepare_impl_method(&item)?);
23+
methods.push(prepare_impl_method(item)?);
2424
}
2525

2626
let ty = item.self_ty.as_ref().clone();
2727
(ty, &item.generics, None)
2828
}
2929
syn::Item::Trait(item) => {
30-
for param in item.generics.params.iter() {
31-
bail!(param, "tracked traits cannot be generic")
30+
if let Some(first) = item.generics.params.first() {
31+
bail!(first, "tracked traits cannot be generic")
3232
}
3333

3434
for item in &item.items {
35-
methods.push(prepare_trait_method(&item)?);
35+
methods.push(prepare_trait_method(item)?);
3636
}
3737

3838
let name = &item.ident;
3939
let ty = parse_quote! { dyn #name + '__comemo_dynamic };
40-
(ty, &item.generics, Some(name.clone()))
40+
(ty, &item.generics, Some(item.ident.clone()))
4141
}
4242
_ => bail!(item, "`track` can only be applied to impl blocks and traits"),
4343
};
4444

4545
// Produce the necessary items for the type to become trackable.
46+
let variants = create_variants(&methods);
4647
let scope = create(&ty, generics, trait_, &methods)?;
4748

4849
Ok(quote! {
4950
#item
50-
const _: () = { #scope };
51+
const _: () = {
52+
#variants
53+
#scope
54+
};
5155
})
5256
}
5357

@@ -175,6 +179,43 @@ fn prepare_method(vis: syn::Visibility, sig: &syn::Signature) -> Result<Method>
175179
})
176180
}
177181

182+
/// Produces the variants for the constraint.
183+
fn create_variants(methods: &[Method]) -> TokenStream {
184+
let variants = methods.iter().map(create_variant);
185+
let is_mutable_variants = methods.iter().map(|m| {
186+
let name = &m.sig.ident;
187+
let mutable = m.mutable;
188+
quote! { __ComemoVariant::#name(..) => #mutable }
189+
});
190+
191+
let is_mutable = (!methods.is_empty())
192+
.then(|| {
193+
quote! {
194+
match &self.0 {
195+
#(#is_mutable_variants),*
196+
}
197+
}
198+
})
199+
.unwrap_or_else(|| quote! { false });
200+
201+
quote! {
202+
#[derive(Clone, PartialEq, Hash)]
203+
pub struct __ComemoCall(__ComemoVariant);
204+
205+
impl ::comemo::internal::Call for __ComemoCall {
206+
fn is_mutable(&self) -> bool {
207+
#is_mutable
208+
}
209+
}
210+
211+
#[derive(Clone, PartialEq, Hash)]
212+
#[allow(non_camel_case_types)]
213+
enum __ComemoVariant {
214+
#(#variants,)*
215+
}
216+
}
217+
}
218+
178219
/// Produce the necessary items for a type to become trackable.
179220
fn create(
180221
ty: &syn::Type,
@@ -229,26 +270,32 @@ fn create(
229270
};
230271

231272
// Prepare replying.
273+
let immutable = methods.iter().all(|m| !m.mutable);
232274
let replays = methods.iter().map(create_replay);
233-
let replay = methods.iter().any(|m| m.mutable).then(|| {
275+
let replay = (!immutable).then(|| {
234276
quote! {
235277
constraint.replay(|call| match &call.0 { #(#replays,)* });
236278
}
237279
});
238280

239281
// Prepare variants and wrapper methods.
240-
let variants = methods.iter().map(create_variant);
241282
let wrapper_methods = methods
242283
.iter()
243284
.filter(|m| !m.mutable)
244285
.map(|m| create_wrapper(m, false));
245286
let wrapper_methods_mut = methods.iter().map(|m| create_wrapper(m, true));
246287

288+
let constraint = if immutable {
289+
quote! { ImmutableConstraint }
290+
} else {
291+
quote! { MutableConstraint }
292+
};
293+
247294
Ok(quote! {
248-
impl #impl_params ::comemo::Track for #ty #where_clause {}
295+
impl #impl_params ::comemo::Track for #ty #where_clause {}
249296

250-
impl #impl_params ::comemo::Validate for #ty #where_clause {
251-
type Constraint = ::comemo::internal::Constraint<__ComemoCall>;
297+
impl #impl_params ::comemo::Validate for #ty #where_clause {
298+
type Constraint = ::comemo::internal::#constraint<__ComemoCall>;
252299

253300
#[inline]
254301
fn validate(&self, constraint: &Self::Constraint) -> bool {
@@ -267,15 +314,6 @@ fn create(
267314
}
268315
}
269316

270-
#[derive(Clone, PartialEq, Hash)]
271-
pub struct __ComemoCall(__ComemoVariant);
272-
273-
#[derive(Clone, PartialEq, Hash)]
274-
#[allow(non_camel_case_types)]
275-
enum __ComemoVariant {
276-
#(#variants,)*
277-
}
278-
279317
#[doc(hidden)]
280318
impl #impl_params ::comemo::internal::Surfaces for #ty #where_clause {
281319
type Surface<#t> = __ComemoSurface #type_params_t where Self: #t;
@@ -323,7 +361,6 @@ fn create(
323361
impl #impl_params_t #prefix __ComemoSurfaceMut #type_params_t {
324362
#(#wrapper_methods_mut)*
325363
}
326-
327364
})
328365
}
329366

@@ -370,10 +407,9 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
370407
let vis = &method.vis;
371408
let sig = &method.sig;
372409
let args = &method.args;
373-
let mutable = method.mutable;
374410
let to_parts = if !tracked_mut {
375411
quote! { to_parts_ref(self.0) }
376-
} else if !mutable {
412+
} else if !method.mutable {
377413
quote! { to_parts_mut_ref(&self.0) }
378414
} else {
379415
quote! { to_parts_mut_mut(&mut self.0) }
@@ -389,7 +425,6 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
389425
constraint.push(
390426
__ComemoCall(__comemo_variant),
391427
::comemo::internal::hash(&output),
392-
#mutable,
393428
);
394429
}
395430
output

src/accelerate.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
use std::collections::HashMap;
2+
use std::sync::atomic::{AtomicUsize, Ordering};
3+
4+
use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard};
5+
6+
/// The global list of currently alive accelerators.
7+
static ACCELERATORS: RwLock<(usize, Vec<Accelerator>)> = RwLock::new((0, Vec::new()));
8+
9+
/// The current ID of the accelerator.
10+
static ID: AtomicUsize = AtomicUsize::new(0);
11+
12+
/// The type of each individual accelerator.
13+
///
14+
/// Maps from call hashes to return hashes.
15+
type Accelerator = Mutex<HashMap<u128, u128>>;
16+
17+
/// Generate a new accelerator.
18+
pub fn id() -> usize {
19+
// Get the next ID.
20+
ID.fetch_add(1, Ordering::SeqCst)
21+
}
22+
23+
/// Evict the accelerators.
24+
pub fn evict() {
25+
let mut accelerators = ACCELERATORS.write();
26+
let (offset, vec) = &mut *accelerators;
27+
28+
// Update the offset.
29+
*offset = ID.load(Ordering::SeqCst);
30+
31+
// Clear all accelerators while keeping the memory allocated.
32+
vec.iter_mut().for_each(|accelerator| accelerator.lock().clear())
33+
}
34+
35+
/// Get an accelerator by ID.
36+
pub fn get(id: usize) -> Option<MappedRwLockReadGuard<'static, Accelerator>> {
37+
// We always lock the accelerators, as we need to make sure that the
38+
// accelerator is not removed while we are reading it.
39+
let mut accelerators = ACCELERATORS.read();
40+
41+
let mut i = id.checked_sub(accelerators.0)?;
42+
if i >= accelerators.1.len() {
43+
drop(accelerators);
44+
resize(i + 1);
45+
accelerators = ACCELERATORS.read();
46+
47+
// Because we release the lock before resizing the accelerator, we need
48+
// to check again whether the ID is still valid because another thread
49+
// might evicted the cache.
50+
i = id.checked_sub(accelerators.0)?;
51+
}
52+
53+
Some(RwLockReadGuard::map(accelerators, move |(_, vec)| &vec[i]))
54+
}
55+
56+
/// Adjusts the amount of accelerators.
57+
#[cold]
58+
fn resize(len: usize) {
59+
let mut pair = ACCELERATORS.write();
60+
if len > pair.1.len() {
61+
pair.1.resize_with(len, || Mutex::new(HashMap::new()));
62+
}
63+
}

0 commit comments

Comments
 (0)