diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index fed0594a6..d6d131abf 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -118,8 +118,7 @@ macro_rules! setup_input_struct { } } - pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { - let zalsa_mut = db.zalsa_mut(); + pub fn ingredient_mut(zalsa_mut: &mut $zalsa::Zalsa) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { zalsa_mut.new_revision(); let index = zalsa_mut.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); let (ingredient, runtime) = zalsa_mut.lookup_ingredient_mut(index); @@ -208,8 +207,10 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = $Configuration::ingredient_(db.zalsa()).field( - db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + let fields = $Configuration::ingredient_(zalsa).field( + zalsa, + zalsa_local, self, $field_index, ); @@ -228,7 +229,8 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let (ingredient, revision) = $Configuration::ingredient_mut(db.as_dyn_database_mut()); + let zalsa = db.zalsa_mut(); + let (ingredient, revision) = $Configuration::ingredient_mut(zalsa); $zalsa::input::SetterImpl::new( revision, self, @@ -267,7 +269,8 @@ macro_rules! setup_input_struct { $(for<'__trivial_bounds> $field_ty: std::fmt::Debug),* { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).leak_fields(db, this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).leak_fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this)); $( @@ -296,11 +299,11 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database { - let zalsa = db.zalsa(); + let (zalsa, zalsa_local) = db.zalsas(); let current_revision = zalsa.current_revision(); let ingredient = $Configuration::ingredient_(zalsa); let (fields, revision, durabilities) = builder::builder_into_inner(self, current_revision); - ingredient.new_input(db.as_dyn_database(), fields, revision, durabilities) + ingredient.new_input(zalsa, zalsa_local, fields, revision, durabilities) } } diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 9b121fdc1..b637586e5 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -149,15 +149,11 @@ macro_rules! setup_interned_struct { } impl $Configuration { - pub fn ingredient(db: &Db) -> &$zalsa_struct::IngredientImpl - where - Db: ?Sized + $zalsa::Database, + pub fn ingredient(zalsa: &$zalsa::Zalsa) -> &$zalsa_struct::IngredientImpl { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); - let zalsa = db.zalsa(); - // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only // ingredient created by our jar is the struct ingredient. unsafe { @@ -239,7 +235,8 @@ macro_rules! setup_interned_struct { $field_ty: $zalsa::interned::HashEqLike<$indexed_ty>, )* { - $Configuration::ingredient(db).intern(db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + $Configuration::ingredient(zalsa).intern(zalsa, zalsa_local, StructKey::<$db_lt>($($field_id,)* std::marker::PhantomData::default()), |_, data| ($($zalsa::interned::Lookup::into_owned(data.$field_index),)*)) } @@ -250,7 +247,8 @@ macro_rules! setup_interned_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), self); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient(zalsa).fields(zalsa, self); $zalsa::return_mode_expression!( $field_option, $field_ty, @@ -262,7 +260,8 @@ macro_rules! setup_interned_struct { /// Default debug formatting for this struct (may be useful if you define your own `Debug` impl) pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).fields(db.as_dyn_database(), this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient(zalsa).fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); $( let f = f.field(stringify!($field_id), &fields.$field_index); diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index a79007592..77325a484 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -175,17 +175,21 @@ macro_rules! setup_tracked_fn { impl $Configuration { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); + Self::fn_ingredient_(db, zalsa) + } + #[inline] + fn fn_ingredient_<'z>(db: &dyn $Db, zalsa: &'z $zalsa::Zalsa) -> &'z $zalsa::function::IngredientImpl<$Configuration> { // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the first // ingredient created by our jar is the function ingredient. unsafe { $FN_CACHE.get_or_create(zalsa, || zalsa.lookup_jar_by_type::<$fn_name>()) } - .get_or_init(|| ::zalsa_register_downcaster(db)) + .get_or_init(|| *::zalsa_register_downcaster(db)) } pub fn fn_ingredient_mut(db: &mut dyn $Db) -> &mut $zalsa::function::IngredientImpl { - let view = ::zalsa_register_downcaster(db); + let view = *::zalsa_register_downcaster(db); let zalsa_mut = db.zalsa_mut(); let index = zalsa_mut.lookup_jar_by_type::<$fn_name>(); let (ingredient, _) = zalsa_mut.lookup_ingredient_mut(index); @@ -199,7 +203,12 @@ macro_rules! setup_tracked_fn { db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { let zalsa = db.zalsa(); - + Self::intern_ingredient_(zalsa) + } + #[inline] + fn intern_ingredient_<'z>( + zalsa: &'z $zalsa::Zalsa + ) -> &'z $zalsa::interned::IngredientImpl<$Configuration> { // SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the second // ingredient created by our jar is the interned ingredient (given `needs_interner`). unsafe { @@ -257,12 +266,12 @@ macro_rules! setup_tracked_fn { $($cycle_recovery_fn)*(db, value, count, $($input_id),*) } - fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { + fn id_to_input<$db_lt>(zalsa: &$db_lt $zalsa::Zalsa, key: salsa::Id) -> Self::Input<$db_lt> { $zalsa::macro_if! { if $needs_interner { - $Configuration::intern_ingredient(db).data(db.as_dyn_database(), key).clone() + $Configuration::intern_ingredient_(zalsa).data(zalsa, key).clone() } else { - $zalsa::FromIdWithDb::from_id(key, db.zalsa()) + $zalsa::FromIdWithDb::from_id(key, zalsa) } } } @@ -340,9 +349,10 @@ macro_rules! setup_tracked_fn { ) -> Vec<&$db_lt A> { use salsa::plumbing as $zalsa; let key = $zalsa::macro_if! { - if $needs_interner { - $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data) - } else { + if $needs_interner {{ + let (zalsa, zalsa_local) = $db.zalsas(); + $Configuration::intern_ingredient($db).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data) + }} else { $zalsa::AsId::as_id(&($($input_id),*)) } }; @@ -380,14 +390,17 @@ macro_rules! setup_tracked_fn { } $zalsa::attach($db, || { + let (zalsa, zalsa_local) = $db.zalsas(); let result = $zalsa::macro_if! { if $needs_interner { { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data); - $Configuration::fn_ingredient($db).fetch($db, key) + let key = $Configuration::intern_ingredient_(zalsa).intern_id(zalsa, zalsa_local, ($($input_id),*), |_, data| data); + $Configuration::fn_ingredient_($db, zalsa).fetch($db, zalsa, zalsa_local, key) } } else { - $Configuration::fn_ingredient($db).fetch($db, $zalsa::AsId::as_id(&($($input_id),*))) + { + $Configuration::fn_ingredient_($db, zalsa).fetch($db, zalsa, zalsa_local, $zalsa::AsId::as_id(&($($input_id),*))) + } } }; diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 5545a44fd..f92b1ac5f 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -282,8 +282,9 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - $Configuration::ingredient(db.as_dyn_database()).new_struct( - db.as_dyn_database(), + let (zalsa, zalsa_local) = db.zalsas(); + $Configuration::ingredient_(zalsa).new_struct( + zalsa,zalsa_local, ($($field_id,)*) ) } @@ -295,8 +296,8 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let db = db.as_dyn_database(); - let fields = $Configuration::ingredient(db).tracked_field(db, self, $relative_tracked_index); + let (zalsa, zalsa_local) = db.zalsas(); + let fields = $Configuration::ingredient_(zalsa).tracked_field(zalsa, zalsa_local, self, $relative_tracked_index); $crate::return_mode_expression!( $tracked_option, $tracked_ty, @@ -312,8 +313,8 @@ macro_rules! setup_tracked_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + $zalsa::Database, { - let db = db.as_dyn_database(); - let fields = $Configuration::ingredient(db).untracked_field(db, self); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).untracked_field(zalsa, self); $crate::return_mode_expression!( $untracked_option, $untracked_ty, @@ -335,7 +336,8 @@ macro_rules! setup_tracked_struct { $(for<$db_lt> $field_ty: std::fmt::Debug),* { $zalsa::with_attached_database(|db| { - let fields = $Configuration::ingredient(db).leak_fields(db, this); + let zalsa = db.zalsa(); + let fields = $Configuration::ingredient_(zalsa).leak_fields(zalsa, this); let mut f = f.debug_struct(stringify!($Struct)); let f = f.field("[salsa id]", &$zalsa::AsId::as_id(&this)); $( diff --git a/components/salsa-macros/src/db.rs b/components/salsa-macros/src/db.rs index 12ee48917..2c49604ad 100644 --- a/components/salsa-macros/src/db.rs +++ b/components/salsa-macros/src/db.rs @@ -110,18 +110,14 @@ impl DbMacro { let trait_name = &input.ident; input.items.push(parse_quote! { #[doc(hidden)] - fn zalsa_register_downcaster(&self) -> salsa::plumbing::DatabaseDownCaster; + fn zalsa_register_downcaster(&self) -> &salsa::plumbing::DatabaseDownCaster; }); - let comment = format!(" Downcast a [`dyn Database`] to a [`dyn {trait_name}`]"); + let comment = format!(" downcast `Self` to a [`dyn {trait_name}`]"); input.items.push(parse_quote! { #[doc = #comment] - /// - /// # Safety - /// - /// The input database must be of type `Self`. #[doc(hidden)] - unsafe fn downcast(db: &dyn salsa::plumbing::Database) -> &dyn #trait_name where Self: Sized; + fn downcast(&self) -> &dyn #trait_name where Self: Sized; }); Ok(()) } @@ -138,17 +134,17 @@ impl DbMacro { #[cold] #[inline(never)] #[doc(hidden)] - fn zalsa_register_downcaster(&self) -> salsa::plumbing::DatabaseDownCaster { - salsa::plumbing::views(self).add(::downcast) + fn zalsa_register_downcaster(&self) -> &salsa::plumbing::DatabaseDownCaster { + salsa::plumbing::views(self).add::(unsafe { + ::std::mem::transmute(::downcast as fn(_) -> _) + }) } }); input.items.push(parse_quote! { #[doc(hidden)] #[inline(always)] - unsafe fn downcast(db: &dyn salsa::plumbing::Database) -> &dyn #TraitPath where Self: Sized { - debug_assert_eq!(db.type_id(), ::core::any::TypeId::of::()); - // SAFETY: The input database must be of type `Self`. - unsafe { &*salsa::plumbing::transmute_data_ptr::(db) } + fn downcast(&self) -> &dyn #TraitPath where Self: Sized { + self } }); Ok(()) diff --git a/src/accumulator.rs b/src/accumulator.rs index 3b1358c60..4bd1280a7 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -102,7 +102,8 @@ impl Ingredient for IngredientImpl { unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, diff --git a/src/database.rs b/src/database.rs index b840398ff..30178b2da 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,13 +1,39 @@ -use std::any::Any; use std::borrow::Cow; +use std::ptr::NonNull; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, ZalsaDatabase}; use crate::{Durability, Revision}; +#[derive(Copy, Clone)] +pub struct RawDatabase<'db> { + pub(crate) ptr: NonNull<()>, + _marker: std::marker::PhantomData<&'db dyn Database>, +} + +impl<'db, Db: Database + ?Sized> From<&'db Db> for RawDatabase<'db> { + #[inline] + fn from(db: &'db Db) -> Self { + RawDatabase { + ptr: NonNull::from(db).cast(), + _marker: std::marker::PhantomData, + } + } +} + +impl<'db, Db: Database + ?Sized> From<&'db mut Db> for RawDatabase<'db> { + #[inline] + fn from(db: &'db mut Db) -> Self { + RawDatabase { + ptr: NonNull::from(db).cast(), + _marker: std::marker::PhantomData, + } + } +} + /// The trait implemented by all Salsa databases. /// You can create your own subtraits of this trait using the `#[salsa::db]`(`crate::db`) procedural macro. -pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { +pub trait Database: Send + ZalsaDatabase + AsDynDatabase { /// Enforces current LRU limits, evicting entries if necessary. /// /// **WARNING:** Just like an ordinary write, this method triggers @@ -84,28 +110,27 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { #[cold] #[inline(never)] #[doc(hidden)] - fn zalsa_register_downcaster(&self) -> DatabaseDownCaster { + fn zalsa_register_downcaster(&self) -> &DatabaseDownCaster { self.zalsa().views().downcaster_for::() // The no-op downcaster is special cased in view caster construction. } #[doc(hidden)] #[inline(always)] - unsafe fn downcast(db: &dyn Database) -> &dyn Database + fn downcast(&self) -> &dyn Database where Self: Sized, { // No-op - db + self } } /// Upcast to a `dyn Database`. /// -/// Only required because upcasts not yet stabilized (*grr*). +/// Only required because upcasting does not work for unsized generic parameters. pub trait AsDynDatabase { fn as_dyn_database(&self) -> &dyn Database; - fn as_dyn_database_mut(&mut self) -> &mut dyn Database; } impl AsDynDatabase for T { @@ -113,30 +138,12 @@ impl AsDynDatabase for T { fn as_dyn_database(&self) -> &dyn Database { self } - - #[inline(always)] - fn as_dyn_database_mut(&mut self) -> &mut dyn Database { - self - } } pub fn current_revision(db: &Db) -> Revision { db.zalsa().current_revision() } -impl dyn Database { - /// Upcasts `self` to the given view. - /// - /// # Panics - /// - /// If the view has not been added to the database (see [`crate::views::Views`]). - #[track_caller] - pub fn as_view(&self) -> &DbView { - let views = self.zalsa().views(); - views.downcaster_for().downcast(self) - } -} - #[cfg(feature = "salsa_unstable")] pub use memory_usage::IngredientInfo; diff --git a/src/function.rs b/src/function.rs index ceb006feb..7642d4bab 100644 --- a/src/function.rs +++ b/src/function.rs @@ -10,6 +10,7 @@ use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy, ProvisionalStatus, }; +use crate::database::RawDatabase; use crate::function::delete::DeletedEntries; use crate::function::sync::{ClaimResult, SyncTable}; use crate::ingredient::{Ingredient, WaitForResult}; @@ -22,7 +23,7 @@ use crate::table::Table; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; -use crate::{Database, Id, Revision}; +use crate::{Id, Revision}; mod accumulated; mod backdate; @@ -68,10 +69,9 @@ pub trait Configuration: Any { /// This invokes user code in form of the `Eq` impl. fn values_equal<'db>(old_value: &Self::Output<'db>, new_value: &Self::Output<'db>) -> bool; - // FIXME: This should take a `&Zalsa` /// Convert from the id used internally to the value that execute is expecting. /// This is a no-op if the input to the function is a salsa struct. - fn id_to_input(db: &Self::DbView, key: Id) -> Self::Input<'_>; + fn id_to_input(zalsa: &Zalsa, key: Id) -> Self::Input<'_>; /// Returns the size of any heap allocations in the output value, in bytes. fn heap_size(_value: &Self::Output<'_>) -> usize { @@ -124,7 +124,7 @@ pub struct IngredientImpl { /// Used to find memos to throw out when we have too many memoized values. lru: lru::Lru, - /// A downcaster from `dyn Database` to `C::DbView`. + /// An downcaster to `C::DbView`. /// /// # Safety /// @@ -261,7 +261,8 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + db: RawDatabase<'_>, input: Id, revision: Revision, cycle_heads: &mut CycleHeads, @@ -370,12 +371,13 @@ where C::CYCLE_STRATEGY } - fn accumulated<'db>( + unsafe fn accumulated<'db>( &'db self, - db: &'db dyn Database, + db: RawDatabase<'db>, key_index: Id, ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { - let db = self.view_caster().downcast(db); + // SAFETY: The `db` belongs to the ingredient as per caller invariant + let db = unsafe { self.view_caster().downcast_unchecked(db) }; self.accumulated_map(db, key_index) } } diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 47fe09a84..a65804e64 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -4,7 +4,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::hash::FxHashSet; use crate::zalsa::ZalsaDatabase; use crate::zalsa_local::QueryOriginRef; -use crate::{AsDynDatabase, DatabaseKeyIndex, Id}; +use crate::{DatabaseKeyIndex, Id}; impl IngredientImpl where @@ -37,9 +37,8 @@ where let mut output = vec![]; // First ensure the result is up to date - self.fetch(db, key); + self.fetch(db, zalsa, zalsa_local, key); - let db = db.as_dyn_database(); let db_key = self.database_key_index(key); let mut visited: FxHashSet = FxHashSet::default(); let mut stack: Vec = vec![db_key]; @@ -54,7 +53,9 @@ where let ingredient = zalsa.lookup_ingredient(k.ingredient_index()); // Extend `output` with any values accumulated by `k`. - let (accumulated_map, input) = ingredient.accumulated(db, k.key_index()); + // SAFETY: `db` owns the `ingredient` + let (accumulated_map, input) = + unsafe { ingredient.accumulated(db.into(), k.key_index()) }; if let Some(accumulated_map) = accumulated_map { accumulated_map.extend_with_accumulated(accumulator.index(), &mut output); } diff --git a/src/function/execute.rs b/src/function/execute.rs index 9013ee7fe..2690d1a5c 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -4,7 +4,7 @@ use crate::function::{Configuration, IngredientImpl}; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions}; -use crate::{Event, EventKind, Id, Revision}; +use crate::{Event, EventKind, Id}; impl IngredientImpl where @@ -41,16 +41,11 @@ where let (new_value, mut revisions) = match C::CYCLE_STRATEGY { CycleRecoveryStrategy::Panic => { - Self::execute_query(db, active_query, opt_old_memo, zalsa.current_revision(), id) + Self::execute_query(db, zalsa, active_query, opt_old_memo, id) } CycleRecoveryStrategy::FallbackImmediate => { - let (mut new_value, mut revisions) = Self::execute_query( - db, - active_query, - opt_old_memo, - zalsa.current_revision(), - id, - ); + let (mut new_value, mut revisions) = + Self::execute_query(db, zalsa, active_query, opt_old_memo, id); if let Some(cycle_heads) = revisions.cycle_heads_mut() { // Did the new result we got depend on our own provisional value, in a cycle? @@ -77,7 +72,7 @@ where let active_query = db .zalsa_local() .push_query(database_key_index, IterationCount::initial()); - new_value = C::cycle_initial(db, C::id_to_input(db, id)); + new_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); revisions = active_query.pop(); // We need to set `cycle_heads` and `verified_final` because it needs to propagate to the callers. // When verifying this, we will see we have fallback and mark ourselves verified. @@ -136,13 +131,8 @@ where let mut opt_last_provisional: Option<&Memo<'db, C>> = None; loop { let previous_memo = opt_last_provisional.or(opt_old_memo); - let (mut new_value, mut revisions) = Self::execute_query( - db, - active_query, - previous_memo, - zalsa.current_revision(), - id, - ); + let (mut new_value, mut revisions) = + Self::execute_query(db, zalsa, active_query, previous_memo, id); // Did the new result we got depend on our own provisional value, in a cycle? if let Some(cycle_heads) = revisions @@ -192,7 +182,7 @@ where db, &new_value, iteration_count.as_u32(), - C::id_to_input(db, id), + C::id_to_input(zalsa, id), ) { crate::CycleRecoveryAction::Iterate => {} crate::CycleRecoveryAction::Fallback(fallback_value) => { @@ -258,9 +248,9 @@ where #[inline] fn execute_query<'db>( db: &'db C::DbView, + zalsa: &'db Zalsa, active_query: ActiveQueryGuard<'db>, opt_old_memo: Option<&Memo<'db, C>>, - current_revision: Revision, id: Id, ) -> (C::Output<'db>, QueryRevisions) { if let Some(old_memo) = opt_old_memo { @@ -275,14 +265,16 @@ where // * ensure that tracked struct created during the previous iteration // (and are owned by the query) are alive even if the query in this iteration no longer creates them. // * ensure the final returned memo depends on all inputs from all iterations. - if old_memo.may_be_provisional() && old_memo.verified_at.load() == current_revision { + if old_memo.may_be_provisional() + && old_memo.verified_at.load() == zalsa.current_revision() + { active_query.seed_iteration(&old_memo.revisions); } } // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! - let new_value = C::execute(db, C::id_to_input(db, id)); + let new_value = C::execute(db, C::id_to_input(zalsa, id)); (new_value, active_query.pop()) } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 252bba124..d6de9d9cb 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -2,7 +2,7 @@ use crate::cycle::{CycleHeads, CycleRecoveryStrategy, IterationCount}; use crate::function::memo::Memo; use crate::function::sync::ClaimResult; use crate::function::{Configuration, IngredientImpl, VerifyResult}; -use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; +use crate::zalsa::{MemoIngredientIndex, Zalsa}; use crate::zalsa_local::{QueryRevisions, ZalsaLocal}; use crate::Id; @@ -10,8 +10,13 @@ impl IngredientImpl where C: Configuration, { - pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> { - let (zalsa, zalsa_local) = db.zalsas(); + pub fn fetch<'db>( + &'db self, + db: &'db C::DbView, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, + id: Id, + ) -> &'db C::Output<'db> { zalsa.unwind_if_revision_cancelled(zalsa_local); let database_key_index = self.database_key_index(id); @@ -175,7 +180,7 @@ where inserting and returning fixpoint initial value" ); let revisions = QueryRevisions::fixpoint_initial(database_key_index); - let initial_value = C::cycle_initial(db, C::id_to_input(db, id)); + let initial_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); Some(self.insert_memo( zalsa, id, @@ -189,7 +194,7 @@ where ); let active_query = zalsa_local.push_query(database_key_index, IterationCount::initial()); - let fallback_value = C::cycle_initial(db, C::id_to_input(db, id)); + let fallback_value = C::cycle_initial(db, C::id_to_input(zalsa, id)); let mut revisions = active_query.pop(); revisions.set_cycle_heads(CycleHeads::initial(database_key_index)); // We need this for `cycle_heads()` to work. We will unset this in the outer `execute()`. diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 7df0a41fe..20e82d1fa 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -7,7 +7,7 @@ use crate::key::DatabaseKeyIndex; use crate::sync::atomic::Ordering; use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{QueryEdgeKind, QueryOriginRef, ZalsaLocal}; -use crate::{AsDynDatabase as _, Id, Revision}; +use crate::{Id, Revision}; /// Result of memo validation. pub enum VerifyResult { @@ -434,8 +434,6 @@ where return VerifyResult::Changed; } - let dyn_db = db.as_dyn_database(); - let mut inputs = InputAccumulatedValues::Empty; // Fully tracked inputs? Iterate over the inputs and check them, one by one. // @@ -447,7 +445,7 @@ where match edge.kind() { QueryEdgeKind::Input(dependency_index) => { match dependency_index.maybe_changed_after( - dyn_db, + db.into(), zalsa, old_memo.verified_at.load(), cycle_heads, diff --git a/src/function/memo.rs b/src/function/memo.rs index d6a872b69..a478b1d46 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -497,7 +497,7 @@ mod _memory_usage { unimplemented!() } - fn id_to_input(_: &Self::DbView, _: Id) -> Self::Input<'_> { + fn id_to_input(_: &Zalsa, _: Id) -> Self::Input<'_> { unimplemented!() } diff --git a/src/ingredient.rs b/src/ingredient.rs index fb567948b..12b8ebcba 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -5,6 +5,7 @@ use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues use crate::cycle::{ empty_cycle_heads, CycleHeads, CycleRecoveryStrategy, IterationCount, ProvisionalStatus, }; +use crate::database::RawDatabase; use crate::function::VerifyResult; use crate::runtime::Running; use crate::sync::Arc; @@ -12,7 +13,7 @@ use crate::table::memo::MemoTableTypes; use crate::table::Table; use crate::zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; -use crate::{Database, DatabaseKeyIndex, Id, Revision}; +use crate::{DatabaseKeyIndex, Id, Revision}; /// A "jar" is a group of ingredients that are added atomically. /// @@ -45,9 +46,10 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// # Safety /// /// The passed in database needs to be the same one that the ingredient was created with. - unsafe fn maybe_changed_after<'db>( - &'db self, - db: &'db dyn Database, + unsafe fn maybe_changed_after( + &self, + zalsa: &crate::zalsa::Zalsa, + db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, cycle_heads: &mut CycleHeads, @@ -159,9 +161,13 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// What values were accumulated during the creation of the value at `key_index` /// (if any). - fn accumulated<'db>( + /// + /// # Safety + /// + /// The passed in database needs to be the same one that the ingredient was created with. + unsafe fn accumulated<'db>( &'db self, - db: &'db dyn Database, + db: RawDatabase<'db>, key_index: Id, ) -> (Option<&'db AccumulatedMap>, InputAccumulatedValues) { let _ = (db, key_index); @@ -171,7 +177,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// Returns memory usage information about any instances of the ingredient, /// if applicable. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, _db: &dyn Database) -> Option> { + fn memory_usage(&self, _db: &dyn crate::Database) -> Option> { None } } @@ -179,7 +185,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { impl dyn Ingredient { /// Equivalent to the `downcast` method on `Any`. /// - /// Because we do not have dyn-upcasting support, we need this workaround. + /// Because we do not have dyn-downcasting support, we need this workaround. pub fn assert_type(&self) -> &T { assert_eq!( self.type_id(), @@ -195,7 +201,7 @@ impl dyn Ingredient { /// Equivalent to the `downcast` methods on `Any`. /// - /// Because we do not have dyn-upcasting support, we need this workaround. + /// Because we do not have dyn-downcasting support, we need this workaround. /// /// # Safety /// @@ -214,7 +220,7 @@ impl dyn Ingredient { /// Equivalent to the `downcast` method on `Any`. /// - /// Because we do not have dyn-upcasting support, we need this workaround. + /// Because we do not have dyn-downcasting support, we need this workaround. pub fn assert_type_mut(&mut self) -> &mut T { assert_eq!( Any::type_id(self), diff --git a/src/input.rs b/src/input.rs index fe25d9b91..af6648e73 100644 --- a/src/input.rs +++ b/src/input.rs @@ -19,7 +19,7 @@ use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes}; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Durability, Id, Revision, Runtime}; +use crate::{zalsa_local, Durability, Id, Revision, Runtime}; pub trait Configuration: Any { const DEBUG_NAME: &'static str; @@ -104,13 +104,12 @@ impl IngredientImpl { pub fn new_input( &self, - db: &dyn Database, + zalsa: &Zalsa, + zalsa_local: &zalsa_local::ZalsaLocal, fields: C::Fields, revisions: C::Revisions, durabilities: C::Durabilities, ) -> C::Struct { - let (zalsa, zalsa_local) = db.zalsas(); - let id = self.singleton.with_scope(|| { zalsa_local.allocate(zalsa, self.ingredient_index, |_| Value:: { fields, @@ -177,11 +176,11 @@ impl IngredientImpl { /// The caller is responsible for selecting the appropriate element. pub fn field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db zalsa_local::ZalsaLocal, id: C::Struct, field_index: usize, ) -> &'db C::Fields { - let (zalsa, zalsa_local) = db.zalsas(); let field_ingredient_index = self.ingredient_index.successor(field_index); let id = id.as_id(); let value = Self::data(zalsa, id); @@ -197,17 +196,13 @@ impl IngredientImpl { #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the input struct. - pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } /// Peek at the field values without recording any read dependency. /// Used for debug printouts. - pub fn leak_fields<'db>(&'db self, db: &'db dyn Database, id: C::Struct) -> &'db C::Fields { - let zalsa = db.zalsa(); + pub fn leak_fields<'db>(&'db self, zalsa: &'db Zalsa, id: C::Struct) -> &'db C::Fields { let id = id.as_id(); let value = Self::data(zalsa, id); &value.fields @@ -225,7 +220,8 @@ impl Ingredient for IngredientImpl { unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, @@ -249,9 +245,9 @@ impl Ingredient for IngredientImpl { /// Returns memory usage information about any inputs. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn crate::Database) -> Option> { let memory_usage = self - .entries(db) + .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) diff --git a/src/input/input_field.rs b/src/input/input_field.rs index f0e4856c8..82ed9889d 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -8,7 +8,7 @@ use crate::input::{Configuration, IngredientImpl, Value}; use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::zalsa::IngredientIndex; -use crate::{Database, Id, Revision}; +use crate::{Id, Revision}; /// Ingredient used to represent the fields of a `#[salsa::input]`. /// @@ -52,12 +52,12 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, input: Id, revision: Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); let value = >::data(zalsa, input); VerifyResult::changed_if(value.revisions[self.field_index] > revision) } diff --git a/src/interned.rs b/src/interned.rs index 812512ff6..e3aecd309 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -21,7 +21,7 @@ use crate::sync::{Arc, Mutex, OnceLock}; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::Slot; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, DatabaseKeyIndex, Event, EventKind, Id, Revision}; +use crate::{DatabaseKeyIndex, Event, EventKind, Id, Revision}; /// Trait that defines the key properties of an interned struct. /// @@ -296,7 +296,8 @@ where /// the database ends up trying to intern or allocate a new value. pub fn intern<'db, Key>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, key: Key, assemble: impl FnOnce(Id, Key) -> C::Fields<'db>, ) -> C::Struct<'db> @@ -304,7 +305,7 @@ where Key: Hash, C::Fields<'db>: HashEqLike, { - FromId::from_id(self.intern_id(db, key, assemble)) + FromId::from_id(self.intern_id(zalsa, zalsa_local, key, assemble)) } /// Intern data to a unique reference. @@ -319,7 +320,8 @@ where /// the database ends up trying to intern or allocate a new value. pub fn intern_id<'db, Key>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, key: Key, assemble: impl FnOnce(Id, Key) -> C::Fields<'db>, ) -> crate::Id @@ -331,8 +333,6 @@ where // so instead we go with this and transmute the lifetime in the `eq` closure C::Fields<'db>: HashEqLike, { - let (zalsa, zalsa_local) = db.zalsas(); - // Record the current revision as active. let current_revision = zalsa.current_revision(); self.revision_queue.record(current_revision); @@ -735,8 +735,7 @@ where } /// Lookup the data for an interned value based on its ID. - pub fn data<'db>(&'db self, db: &'db dyn Database, id: Id) -> &'db C::Fields<'db> { - let zalsa = db.zalsa(); + pub fn data<'db>(&'db self, zalsa: &'db Zalsa, id: Id) -> &'db C::Fields<'db> { let value = zalsa.table().get::>(id); debug_assert!( @@ -761,12 +760,12 @@ where /// Lookup the fields from an interned struct. /// /// Note that this is not "leaking" since no dependency edge is required. - pub fn fields<'db>(&'db self, db: &'db dyn Database, s: C::Struct<'db>) -> &'db C::Fields<'db> { - self.data(db, AsId::as_id(&s)) + pub fn fields<'db>(&'db self, zalsa: &'db Zalsa, s: C::Struct<'db>) -> &'db C::Fields<'db> { + self.data(zalsa, AsId::as_id(&s)) } - pub fn reset(&mut self, db: &mut dyn Database) { - _ = db.zalsa_mut(); + pub fn reset(&mut self, zalsa_mut: &mut Zalsa) { + _ = zalsa_mut; for shard in self.shards.iter() { // We can clear the key maps now that we have cancelled all other handles. @@ -776,11 +775,8 @@ where #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the interned struct. - pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } } @@ -798,13 +794,12 @@ where unsafe fn maybe_changed_after( &self, - db: &dyn Database, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); - // Record the current revision as active. let current_revision = zalsa.current_revision(); self.revision_queue.record(current_revision); @@ -852,7 +847,7 @@ where /// Returns memory usage information about any interned values. #[cfg(all(not(feature = "shuttle"), feature = "salsa_unstable"))] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn crate::Database) -> Option> { use parking_lot::lock_api::RawMutex; for shard in self.shards.iter() { @@ -861,7 +856,7 @@ where } let memory_usage = self - .entries(db) + .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. Additionally, we are holding the locks for all shards. .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) diff --git a/src/key.rs b/src/key.rs index 5883ef9cb..80904e978 100644 --- a/src/key.rs +++ b/src/key.rs @@ -3,7 +3,7 @@ use core::fmt; use crate::cycle::CycleHeads; use crate::function::VerifyResult; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Id}; +use crate::Id; // ANCHOR: DatabaseKeyIndex /// An integer that uniquely identifies a particular query instance within the @@ -36,16 +36,18 @@ impl DatabaseKeyIndex { pub(crate) fn maybe_changed_after( &self, - db: &dyn Database, + db: crate::database::RawDatabase<'_>, zalsa: &Zalsa, last_verified_at: crate::Revision, cycle_heads: &mut CycleHeads, ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient unsafe { + // here, `db` has to be either the correct type already, or a subtype (as far as trait + // hierarchy is concerned) zalsa .lookup_ingredient(self.ingredient_index()) - .maybe_changed_after(db, self.key_index(), last_verified_at, cycle_heads) + .maybe_changed_after(zalsa, db, self.key_index(), last_verified_at, cycle_heads) } } diff --git a/src/lib.rs b/src/lib.rs index 7bc94eec4..2600d9a33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,7 +51,7 @@ pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; pub use self::cancelled::Cancelled; pub use self::cycle::CycleRecoveryAction; -pub use self::database::{AsDynDatabase, Database}; +pub use self::database::Database; pub use self::database_impl::DatabaseImpl; pub use self::durability::Durability; pub use self::event::{Event, EventKind}; diff --git a/src/parallel.rs b/src/parallel.rs index 1d2504b77..8a0bde655 100644 --- a/src/parallel.rs +++ b/src/parallel.rs @@ -1,44 +1,91 @@ use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; -use crate::Database; +use crate::{database::RawDatabase, views::DatabaseDownCaster, Database}; pub fn par_map(db: &Db, inputs: impl IntoParallelIterator, op: F) -> C where - Db: Database + ?Sized, + Db: Database + ?Sized + Send, F: Fn(&Db, T) -> R + Sync + Send, T: Send, R: Send + Sync, C: FromParallelIterator, { + let views = db.zalsa().views(); + let caster = &views.downcaster_for::(); + let db_caster = &views.downcaster_for::(); inputs .into_par_iter() - .map_with(DbForkOnClone(db.fork_db()), |db, element| { - op(db.0.as_view(), element) - }) + .map_with( + DbForkOnClone(db.fork_db(), caster, db_caster), + |db, element| op(db.as_view(), element), + ) .collect() } -struct DbForkOnClone(Box); +struct DbForkOnClone<'views, Db: Database + ?Sized>( + RawDatabase<'static>, + &'views DatabaseDownCaster, + &'views DatabaseDownCaster, +); -impl Clone for DbForkOnClone { +// SAFETY: `T: Send` -> `&own T: Send`, `DbForkOnClone` is an owning pointer +unsafe impl Send for DbForkOnClone<'_, Db> {} + +impl DbForkOnClone<'_, Db> { + fn as_view(&self) -> &Db { + // SAFETY: The downcaster ensures that the pointer is valid for the lifetime of the view. + unsafe { self.1.downcast_unchecked(self.0) } + } +} + +impl Drop for DbForkOnClone<'_, Db> { + fn drop(&mut self) { + // SAFETY: `caster` is derived from a `db` fitting for our database clone + let db = unsafe { self.1.downcast_mut_unchecked(self.0) }; + // SAFETY: `db` has been box allocated and leaked by `fork_db` + _ = unsafe { Box::from_raw(db) }; + } +} + +impl Clone for DbForkOnClone<'_, Db> { fn clone(&self) -> Self { - DbForkOnClone(self.0.fork_db()) + DbForkOnClone( + // SAFETY: `caster` is derived from a `db` fitting for our database clone + unsafe { self.2.downcast_unchecked(self.0) }.fork_db(), + self.1, + self.2, + ) } } -pub fn join(db: &Db, a: A, b: B) -> (RA, RB) +pub fn join(db: &Db, a: A, b: B) -> (RA, RB) where A: FnOnce(&Db) -> RA + Send, B: FnOnce(&Db) -> RB + Send, RA: Send, RB: Send, { + #[derive(Copy, Clone)] + struct AssertSend(T); + // SAFETY: We send owning pointers over, which are Send, given the `Db` type parameter above is Send + unsafe impl Send for AssertSend {} + + let caster = &db.zalsa().views().downcaster_for::(); // we need to fork eagerly, as `rayon::join_context` gives us no option to tell whether we get // moved to another thread before the closure is executed - let db_a = db.fork_db(); - let db_b = db.fork_db(); - rayon::join( - move || a(db_a.as_view::()), - move || b(db_b.as_view::()), - ) + let db_a = AssertSend(db.fork_db()); + let db_b = AssertSend(db.fork_db()); + let res = rayon::join( + // SAFETY: `caster` is derived from a `db` fitting for our database clone + move || a(unsafe { caster.downcast_unchecked({ db_a }.0) }), + // SAFETY: `caster` is derived from a `db` fitting for our database clone + move || b(unsafe { caster.downcast_unchecked({ db_b }.0) }), + ); + + // SAFETY: `db` has been box allocated and leaked by `fork_db` + // FIXME: Clean this mess up, RAII + _ = unsafe { Box::from_raw(caster.downcast_mut_unchecked(db_a.0)) }; + // SAFETY: `db` has been box allocated and leaked by `fork_db` + _ = unsafe { Box::from_raw(caster.downcast_mut_unchecked(db_b.0)) }; + res } diff --git a/src/storage.rs b/src/storage.rs index a8c2abec0..f63981e4f 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use std::panic::RefUnwindSafe; +use crate::database::RawDatabase; use crate::sync::{Arc, Condvar, Mutex}; use crate::zalsa::{ErasedJar, HasJar, Zalsa, ZalsaDatabase}; use crate::zalsa_local::{self, ZalsaLocal}; @@ -245,8 +246,8 @@ unsafe impl ZalsaDatabase for T { } #[inline(always)] - fn fork_db(&self) -> Box { - Box::new(self.clone()) + fn fork_db(&self) -> RawDatabase<'static> { + Box::leak(Box::new(self.clone())).into() } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index f6f4ea440..ec240ebcb 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -23,7 +23,7 @@ use crate::sync::Arc; use crate::table::memo::{MemoTable, MemoTableTypes, MemoTableWithTypesMut}; use crate::table::{Slot, Table}; use crate::zalsa::{IngredientIndex, Zalsa}; -use crate::{Database, Durability, Event, EventKind, Id, Revision}; +use crate::{Durability, Event, EventKind, Id, Revision}; pub mod tracked_field; @@ -375,11 +375,10 @@ where pub fn new_struct<'db>( &'db self, - db: &'db dyn Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, mut fields: C::Fields<'db>, ) -> C::Struct<'db> { - let (zalsa, zalsa_local) = db.zalsas(); - let identity_hash = IdentityHash { ingredient_index: self.ingredient_index, hash: crate::hash::hash(&C::untracked_fields(&fields)), @@ -734,11 +733,11 @@ where /// Used for debugging. pub fn leak_fields<'db>( &'db self, - db: &'db dyn Database, + zalsa: &'db Zalsa, s: C::Struct<'db>, ) -> &'db C::Fields<'db> { let id = AsId::as_id(&s); - let data = Self::data(db.zalsa().table(), id); + let data = Self::data(zalsa.table(), id); data.fields() } @@ -748,11 +747,11 @@ where /// The caller is responsible for selecting the appropriate element. pub fn tracked_field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, + zalsa_local: &'db ZalsaLocal, s: C::Struct<'db>, relative_tracked_index: usize, ) -> &'db C::Fields<'db> { - let (zalsa, zalsa_local) = db.zalsas(); let id = AsId::as_id(&s); let field_ingredient_index = self.ingredient_index.successor(relative_tracked_index); let data = Self::data(zalsa.table(), id); @@ -776,10 +775,9 @@ where /// The caller is responsible for selecting the appropriate element. pub fn untracked_field<'db>( &'db self, - db: &'db dyn crate::Database, + zalsa: &'db Zalsa, s: C::Struct<'db>, ) -> &'db C::Fields<'db> { - let zalsa = db.zalsa(); let id = AsId::as_id(&s); let data = Self::data(zalsa.table(), id); @@ -794,11 +792,8 @@ where #[cfg(feature = "salsa_unstable")] /// Returns all data corresponding to the tracked struct. - pub fn entries<'db>( - &'db self, - db: &'db dyn crate::Database, - ) -> impl Iterator> { - db.zalsa().table().slots_of::>() + pub fn entries<'db>(&'db self, zalsa: &'db Zalsa) -> impl Iterator> { + zalsa.table().slots_of::>() } } @@ -816,7 +811,8 @@ where unsafe fn maybe_changed_after( &self, - _db: &dyn Database, + _zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, _input: Id, _revision: Revision, _cycle_heads: &mut CycleHeads, @@ -863,9 +859,9 @@ where /// Returns memory usage information about any tracked structs. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn crate::Database) -> Option> { let memory_usage = self - .entries(db) + .entries(db.zalsa()) // SAFETY: The memo table belongs to a value that we allocated, so it // has the correct type. .map(|value| unsafe { value.memory_usage(&self.memo_table_types) }) diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index ad3e871e8..587e473fa 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -7,7 +7,7 @@ use crate::sync::Arc; use crate::table::memo::MemoTableTypes; use crate::tracked_struct::{Configuration, Value}; use crate::zalsa::IngredientIndex; -use crate::{Database, Id}; +use crate::Id; /// Created for each tracked struct. /// @@ -55,14 +55,14 @@ where self.ingredient_index } - unsafe fn maybe_changed_after<'db>( - &'db self, - db: &'db dyn Database, + unsafe fn maybe_changed_after( + &self, + zalsa: &crate::zalsa::Zalsa, + _db: crate::database::RawDatabase<'_>, input: Id, revision: crate::Revision, _cycle_heads: &mut CycleHeads, ) -> VerifyResult { - let zalsa = db.zalsa(); let data = >::data(zalsa.table(), input); let field_changed_at = data.revisions[self.field_index]; VerifyResult::changed_if(field_changed_at > revision) diff --git a/src/views.rs b/src/views.rs index 01a0a2de5..d449779c3 100644 --- a/src/views.rs +++ b/src/views.rs @@ -1,10 +1,15 @@ -use std::any::{Any, TypeId}; +use std::{ + any::{Any, TypeId}, + marker::PhantomData, + mem, + ptr::NonNull, +}; -use crate::Database; +use crate::{database::RawDatabase, Database}; /// A `Views` struct is associated with some specific database type /// (a `DatabaseImpl` for some existential `U`). It contains functions -/// to downcast from `dyn Database` to `dyn DbView` for various traits `DbView` via this specific +/// to downcast to `dyn DbView` for various traits `DbView` via this specific /// database type. /// None of these types are known at compilation time, they are all checked /// dynamically through `TypeId` magic. @@ -13,6 +18,7 @@ pub struct Views { view_casters: boxcar::Vec, } +#[derive(Copy, Clone)] struct ViewCaster { /// The id of the target type `dyn DbView` that we can cast to. target_type_id: TypeId, @@ -20,50 +26,69 @@ struct ViewCaster { /// The name of the target type `dyn DbView` that we can cast to. type_name: &'static str, - /// Type-erased function pointer that downcasts from `dyn Database` to `dyn DbView`. + /// Type-erased function pointer that downcasts to `dyn DbView`. cast: ErasedDatabaseDownCasterSig, } impl ViewCaster { - fn new(func: unsafe fn(&dyn Database) -> &DbView) -> ViewCaster { + fn new(func: DatabaseDownCasterSig) -> ViewCaster { ViewCaster { target_type_id: TypeId::of::(), type_name: std::any::type_name::(), // SAFETY: We are type erasing for storage, taking care of unerasing before we call // the function pointer. cast: unsafe { - std::mem::transmute::, ErasedDatabaseDownCasterSig>( - func, - ) + mem::transmute::, ErasedDatabaseDownCasterSig>(func) }, } } } -type ErasedDatabaseDownCasterSig = unsafe fn(&dyn Database) -> *const (); -type DatabaseDownCasterSig = unsafe fn(&dyn Database) -> &DbView; +type ErasedDatabaseDownCasterSig = unsafe fn(RawDatabase<'_>) -> NonNull<()>; +type DatabaseDownCasterSig = unsafe fn(RawDatabase<'_>) -> NonNull; -pub struct DatabaseDownCaster(TypeId, DatabaseDownCasterSig); +#[repr(transparent)] +pub struct DatabaseDownCaster(ViewCaster, PhantomData DbView>); -impl DatabaseDownCaster { - pub fn downcast<'db>(&self, db: &'db dyn Database) -> &'db DbView { - assert_eq!( - self.0, - db.type_id(), - "Database type does not match the expected type for this `Views` instance" - ); - // SAFETY: We've asserted that the database is correct. - unsafe { (self.1)(db) } +impl Copy for DatabaseDownCaster {} +impl Clone for DatabaseDownCaster { + fn clone(&self) -> Self { + *self } +} +impl DatabaseDownCaster { + /// Downcast `db` to `DbView`. + /// + /// # Safety + /// + /// The caller must ensure that `db` is of the correct type. + #[inline] + pub unsafe fn downcast_unchecked<'db>(&self, db: RawDatabase<'db>) -> &'db DbView { + // SAFETY: The caller must ensure that `db` is of the correct type. + // The returned pointer is live for `'db` due to construction of the downcaster functions. + unsafe { (self.unerased_downcaster())(db).as_ref() } + } /// Downcast `db` to `DbView`. /// /// # Safety /// /// The caller must ensure that `db` is of the correct type. - pub unsafe fn downcast_unchecked<'db>(&self, db: &'db dyn Database) -> &'db DbView { + #[inline] + pub unsafe fn downcast_mut_unchecked<'db>(&self, db: RawDatabase<'db>) -> &'db mut DbView { // SAFETY: The caller must ensure that `db` is of the correct type. - unsafe { (self.1)(db) } + // The returned pointer is live for `'db` due to construction of the downcaster functions. + unsafe { (self.unerased_downcaster())(db).as_mut() } + } + + #[inline] + fn unerased_downcaster(&self) -> DatabaseDownCasterSig { + // SAFETY: The type-erased function pointer is guaranteed to be ABI compatible for `DbView` + unsafe { + mem::transmute::>( + self.0.cast, + ) + } } } @@ -71,58 +96,63 @@ impl Views { pub(crate) fn new() -> Self { let source_type_id = TypeId::of::(); let view_casters = boxcar::Vec::new(); - // special case the no-op transformation, that way we skip out on reconstructing the wide pointer - view_casters.push(ViewCaster::new::(|db| db)); + view_casters.push(ViewCaster::new::(|db| db.ptr.cast::())); Self { source_type_id, view_casters, } } - /// Add a new downcaster from `dyn Database` to `dyn DbView`. - pub fn add( + /// Add a new downcaster to `dyn DbView`. + pub fn add( &self, - func: DatabaseDownCasterSig, - ) -> DatabaseDownCaster { - if let Some(view) = self.try_downcaster_for() { - return view; + func: fn(NonNull) -> NonNull, + ) -> &DatabaseDownCaster { + assert_eq!(self.source_type_id, TypeId::of::()); + let target_type_id = TypeId::of::(); + if let Some((_, caster)) = self + .view_casters + .iter() + .find(|(_, u)| u.target_type_id == target_type_id) + { + // SAFETY: The type-erased function pointer is guaranteed to be valid for `DbView` + return unsafe { &*(&raw const *caster).cast::>() }; } - self.view_casters.push(ViewCaster::new::(func)); - DatabaseDownCaster(self.source_type_id, func) + // SAFETY: We are type erasing the function pointer for storage, and we will unerase it + // before we call it. + let caster = unsafe { + mem::transmute::) -> NonNull, DatabaseDownCasterSig>( + func, + ) + }; + let caster = ViewCaster::new::(caster); + let idx = self.view_casters.push(caster); + // SAFETY: The type-erased function pointer is guaranteed to be valid for `DbView` + unsafe { &*(&raw const self.view_casters[idx]).cast::>() } } - /// Retrieve an downcaster function from `dyn Database` to `dyn DbView`. + /// Retrieve an downcaster function to `dyn DbView`. /// /// # Panics /// - /// If the underlying type of `db` is not the same as the database type this upcasts was created for. - pub fn downcaster_for(&self) -> DatabaseDownCaster { - self.try_downcaster_for().unwrap_or_else(|| { - panic!( - "No downcaster registered for type `{}` in `Views`", - std::any::type_name::(), - ) - }) - } - - /// Retrieve an downcaster function from `dyn Database` to `dyn DbView`, if it exists. - #[inline] - pub fn try_downcaster_for(&self) -> Option> { + /// If the underlying type of `db` is not the same as the database type this downcasts was created for. + pub fn downcaster_for(&self) -> &DatabaseDownCaster { let view_type_id = TypeId::of::(); for (_, view) in self.view_casters.iter() { if view.target_type_id == view_type_id { // SAFETY: We are unerasing the type erased function pointer having made sure the - // `TypeId` matches. - return Some(DatabaseDownCaster(self.source_type_id, unsafe { - std::mem::transmute::>( - view.cast, - ) - })); + // TypeId matches. + return unsafe { + &*((view as *const ViewCaster).cast::>()) + }; } } - None + panic!( + "No downcaster registered for type `{}` in `Views`", + std::any::type_name::(), + ); } } diff --git a/src/zalsa.rs b/src/zalsa.rs index 41ece3cae..1cc6ba5f5 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -5,6 +5,7 @@ use std::panic::RefUnwindSafe; use hashbrown::HashMap; use rustc_hash::FxHashMap; +use crate::database::RawDatabase; use crate::hash::TypeIdHasher; use crate::ingredient::{Ingredient, Jar}; use crate::plumbing::SalsaStructInDb; @@ -52,7 +53,7 @@ pub unsafe trait ZalsaDatabase: Any { /// Clone the database. #[doc(hidden)] - fn fork_db(&self) -> Box; + fn fork_db(&self) -> RawDatabase<'static>; } pub fn views(db: &Db) -> &Views { diff --git a/tests/debug_db_contents.rs b/tests/debug_db_contents.rs index 6ab8b212e..a253d8869 100644 --- a/tests/debug_db_contents.rs +++ b/tests/debug_db_contents.rs @@ -22,14 +22,15 @@ fn tracked_fn(db: &dyn salsa::Database, input: InputStruct) -> TrackedStruct<'_> #[test] fn execute() { + use salsa::plumbing::ZalsaDatabase; let db = salsa::DatabaseImpl::new(); let _ = InternedStruct::new(&db, "Salsa".to_string()); let _ = InternedStruct::new(&db, "Salsa2".to_string()); // test interned structs - let interned = InternedStruct::ingredient(&db) - .entries(&db) + let interned = InternedStruct::ingredient(db.zalsa()) + .entries(db.zalsa()) .collect::>(); assert_eq!(interned.len(), 2); @@ -40,7 +41,7 @@ fn execute() { let input = InputStruct::new(&db, 22); let inputs = InputStruct::ingredient(&db) - .entries(&db) + .entries(db.zalsa()) .collect::>(); assert_eq!(inputs.len(), 1); @@ -50,7 +51,7 @@ fn execute() { let computed = tracked_fn(&db, input).field(&db); assert_eq!(computed, 44); let tracked = TrackedStruct::ingredient(&db) - .entries(&db) + .entries(db.zalsa()) .collect::>(); assert_eq!(tracked.len(), 1); diff --git a/tests/interned-structs.rs b/tests/interned-structs.rs index 931b1ab67..a9db074c4 100644 --- a/tests/interned-structs.rs +++ b/tests/interned-structs.rs @@ -132,13 +132,13 @@ fn interning_boxed() { #[test] fn interned_structs_have_public_ingredients() { - use salsa::plumbing::AsId; + use salsa::plumbing::{AsId, ZalsaDatabase}; let db = salsa::DatabaseImpl::new(); let s = InternedString::new(&db, String::from("Hello, world!")); let underlying_id = s.0; - let data = InternedString::ingredient(&db).data(&db, underlying_id.as_id()); + let data = InternedString::ingredient(db.zalsa()).data(db.zalsa(), underlying_id.as_id()); assert_eq!(data, &(String::from("Hello, world!"),)); } diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index 55eb8c06f..3443f3ac2 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -181,7 +181,8 @@ const _: () = { String: zalsa_::interned::HashEqLike, { Configuration_::ingredient(db).intern( - db.as_dyn_database(), + db.zalsa(), + db.zalsa_local(), StructKey::<'db>(data, std::marker::PhantomData::default()), |id, data| { StructData( @@ -195,20 +196,20 @@ const _: () = { where Db_: ?Sized + zalsa_::Database, { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), self); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), self); std::clone::Clone::clone((&fields.0)) } fn other(self, db: &'db Db_) -> InternedString<'db> where Db_: ?Sized + zalsa_::Database, { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), self); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), self); std::clone::Clone::clone((&fields.1)) } #[doc = r" Default debug formatting for this struct (may be useful if you define your own `Debug` impl)"] pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { zalsa_::with_attached_database(|db| { - let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), this); + let fields = Configuration_::ingredient(db).fields(db.zalsa(), this); let mut f = f.debug_struct("InternedString"); let f = f.field("data", &fields.0); let f = f.field("other", &fields.1);