Skip to content

Commit

Permalink
Add Callable to Variant and support marshalling to/from lua
Browse files Browse the repository at this point in the history
  • Loading branch information
mattkleiny authored Aug 2, 2024
1 parent 29800ad commit 9256918
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 59 deletions.
1 change: 0 additions & 1 deletion .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
FROM mcr.microsoft.com/devcontainers/rust:latest

RUN apt-get update && apt-get -y install libgl1-mesa-dev libasound2-dev
RUN rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt
13 changes: 0 additions & 13 deletions core/common/src/abstractions/assets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,6 @@ pub struct AssetContext {
dependencies: Graph<AssetId>,
}

impl AssetContext {
/// Builds a new asset context from a root asset ID.
fn from_asset_id(asset_id: AssetId) -> Self {
let mut dependencies = Graph::default();
let current_node = dependencies.add_node(asset_id);

Self {
current_node,
dependencies,
}
}
}

/// Represents an asset that can be loaded and resolved.
pub trait Asset {
/// Resolves the dependencies of the asset.
Expand Down
80 changes: 78 additions & 2 deletions core/common/src/abstractions/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,78 @@
use std::{marker::PhantomData, panic::RefUnwindSafe};
use std::{
fmt::{Debug, Formatter},
marker::PhantomData,
panic::RefUnwindSafe,
sync::Arc,
};

use crate::{FromVariant, ToVariant, Variant};

use super::VariantError;

/// An error when calling a script callback.
#[derive(Debug)]
pub enum CallbackError {
ExecutionError(String),
InvalidArgument,
}

/// A callback that can be called from a foreign environment.
/// A boxed callable function.
///
/// This is a wrapper around a boxed function that can be called with a list of
/// [`Variant`] arguments and returns a [`Variant`] result.
#[derive(Clone)]
pub struct Callable(Arc<dyn Fn(&[Variant]) -> Result<Variant, CallbackError>>);

impl Callable {
/// Creates a new boxed callable function from the given function.
pub fn new(function: impl Fn(&[Variant]) -> Result<Variant, CallbackError> + 'static) -> Self {
Self(Arc::new(function))
}

/// Creates a new boxed callable function from the given [`Callback`].
pub fn from_callback<R>(callback: impl Callback<R> + 'static) -> Self {
Self(Arc::new(move |args| callback.call(args)))
}

/// Calls the boxed callable function with the given arguments.
pub fn call(&self, args: &[Variant]) -> Result<Variant, CallbackError> {
let callable = self.0.as_ref();

callable(args)
}
}

impl PartialEq for Callable {
#[inline]
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}

impl Debug for Callable {
fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
write!(formatter, "Callable")
}
}

impl ToVariant for Callable {
#[inline]
fn to_variant(&self) -> Variant {
Variant::Callable(self.clone())
}
}

impl FromVariant for Callable {
#[inline]
fn from_variant(variant: Variant) -> Result<Self, VariantError> {
match variant {
Variant::Callable(callable) => Ok(callable),
_ => Err(VariantError::InvalidConversion),
}
}
}

/// Represents a function signature that is callable.
pub trait Callback<R>: RefUnwindSafe {
/// Calls the callback with the given arguments.
fn call(&self, args: &[Variant]) -> Result<Variant, CallbackError>;
Expand Down Expand Up @@ -135,3 +198,16 @@ where
Ok(result.to_variant())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_callable_function_creation_and_execution() {
let callable = Callable::from_callback(|a: u32, b: u32| a + b);
let result = callable.call(&[Variant::U32(1), Variant::U32(2)]).unwrap();

assert_eq!(result, Variant::U32(3));
}
}
46 changes: 7 additions & 39 deletions core/common/src/abstractions/variant.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{any::Any, cmp::Ordering};
use std::{any::Any, cmp::Ordering, fmt::Debug};

use crate::{Color, Color32, Pointer, Quat, StringName, Vec2, Vec3, Vec4};
use crate::{Callable, Color, Color32, Pointer, Quat, StringName, Vec2, Vec3, Vec4};

/// An error that can occur when working with variants.
#[derive(Debug)]
Expand Down Expand Up @@ -47,6 +47,7 @@ pub enum VariantKind {
Quat,
Color,
Color32,
Callable,
UserData,
}

Expand All @@ -55,7 +56,7 @@ pub enum VariantKind {
/// This is an abstraction over the different primitive types that are often
/// shuffled around in the engine. It allows for a more generic API that can
/// handle any type of value.
#[derive(Default, Debug)]
#[derive(Default, Debug, PartialEq)]
pub enum Variant {
#[default]
Null,
Expand All @@ -79,45 +80,10 @@ pub enum Variant {
Quat(Quat),
Color(Color),
Color32(Color32),
Callable(Callable),
UserData(Pointer<dyn Any>),
}

impl PartialEq for Variant {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Variant::Null, Variant::Null) => true,
(Variant::Bool(a), Variant::Bool(b)) => a == b,
(Variant::Char(a), Variant::Char(b)) => a == b,
(Variant::U8(a), Variant::U8(b)) => a == b,
(Variant::U16(a), Variant::U16(b)) => a == b,
(Variant::U32(a), Variant::U32(b)) => a == b,
(Variant::U64(a), Variant::U64(b)) => a == b,
(Variant::I8(a), Variant::I8(b)) => a == b,
(Variant::I16(a), Variant::I16(b)) => a == b,
(Variant::I32(a), Variant::I32(b)) => a == b,
(Variant::I64(a), Variant::I64(b)) => a == b,
(Variant::F32(a), Variant::F32(b)) => a == b,
(Variant::F64(a), Variant::F64(b)) => a == b,
(Variant::String(a), Variant::String(b)) => a == b,
(Variant::StringName(a), Variant::StringName(b)) => a == b,
(Variant::Vec2(a), Variant::Vec2(b)) => a == b,
(Variant::Vec3(a), Variant::Vec3(b)) => a == b,
(Variant::Vec4(a), Variant::Vec4(b)) => a == b,
(Variant::Quat(a), Variant::Quat(b)) => a == b,
(Variant::Color(a), Variant::Color(b)) => a == b,
(Variant::Color32(a), Variant::Color32(b)) => a == b,
(Variant::UserData(a), Variant::UserData(b)) => {
// pointer comparison
let ptr_a = &**a as *const dyn Any;
let ptr_b = &**b as *const dyn Any;

std::ptr::addr_eq(ptr_a, ptr_b)
}
_ => false,
}
}
}

impl Clone for Variant {
fn clone(&self) -> Self {
match self {
Expand All @@ -142,6 +108,7 @@ impl Clone for Variant {
Variant::Quat(value) => Variant::Quat(value.clone()),
Variant::Color(value) => Variant::Color(value.clone()),
Variant::Color32(value) => Variant::Color32(value.clone()),
Variant::Callable(value) => Variant::Callable(value.clone()),
Variant::UserData(value) => Variant::UserData(value.clone()),
}
}
Expand Down Expand Up @@ -172,6 +139,7 @@ impl Variant {
Variant::Quat(_) => VariantKind::Quat,
Variant::Color(_) => VariantKind::Color,
Variant::Color32(_) => VariantKind::Color32,
Variant::Callable(_) => VariantKind::Callable,
Variant::UserData(_) => VariantKind::UserData,
}
}
Expand Down
3 changes: 2 additions & 1 deletion core/common/src/io/formats/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ impl FileFormat for BinaryFileFormat {
stream.write_u8(value.b)?;
stream.write_u8(value.a)?;
}
Variant::UserData(_) => {}
Variant::Callable(_) => todo!(),
Variant::UserData(_) => todo!(),
}
}
Chunk::Sequence(sequence) => {
Expand Down
3 changes: 2 additions & 1 deletion core/common/src/io/formats/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ impl FileFormat for JsonFileFormat {
Variant::Color32(value) => {
stream.write_string(&format!("[{}, {}, {}, {}]", value.r, value.g, value.b, value.a))?;
}
Variant::UserData(_) => {}
Variant::Callable(_) => todo!(),
Variant::UserData(_) => todo!(),
},
Chunk::Sequence(sequence) => {
stream.write_string("[")?;
Expand Down
29 changes: 27 additions & 2 deletions core/common/src/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
pub use mlua::prelude::*;

use crate::{
Callback, Color, Color32, FromVariant, Pointer, Quat, ToStringName, ToVariant, ToVirtualPath, Variant, Vec2, Vec3,
Vec4,
Callable, Callback, CallbackError, Color, Color32, FromVariant, Pointer, Quat, ToStringName, ToVariant, ToVirtualPath, Variant, Vec2, Vec3, Vec4
};

/// A Lua scripting engine.
Expand Down Expand Up @@ -136,6 +135,23 @@ impl<'lua> IntoLua<'lua> for Variant {
Variant::Quat(value) => LuaQuat(value).into_lua(lua)?,
Variant::Color(value) => LuaColor(value).into_lua(lua)?,
Variant::Color32(value) => LuaColor32(value).into_lua(lua)?,
Variant::Callable(callable) => {
// create a Lua function that calls the callable
let function = lua.create_function(move |lua, args: LuaMultiValue| {
let args = args
.into_iter()
.map(|value| Variant::from_lua(value, lua))
.collect::<LuaResult<Vec<_>>>()?;

let result = callable.call(&args).map_err(|error| {
LuaError::RuntimeError(format!("An error occurred calling a function, {:?}", error))
})?;

Ok(result.into_lua(lua)?)
})?;

LuaValue::Function(function)
},
Variant::UserData(value) => LuaValue::LightUserData(LuaLightUserData(value.into_void())),
})
}
Expand All @@ -151,6 +167,15 @@ impl<'lua> FromLua<'lua> for Variant {
LuaValue::Number(value) => Variant::F64(value),
LuaValue::String(value) => Variant::String(value.to_str()?.to_string()),
LuaValue::Table(value) => Variant::UserData(Pointer::new(value.into_owned())),
LuaValue::Function(function) => {
// create a callable that calls the Lua function
let function = function.into_owned();
let callable = Callable::new(move |args| {
function.call(args).map_err(|error| CallbackError::ExecutionError(error.to_string()))
});

Variant::Callable(callable)
},
LuaValue::LightUserData(value) => Variant::UserData(Pointer::from_raw_mut(value.0)),
LuaValue::UserData(value) => match () {
_ if value.is::<LuaVec2>() => Variant::Vec2(value.borrow::<LuaVec2>()?.0),
Expand Down
7 changes: 7 additions & 0 deletions core/common/src/memory/pointer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ impl<T: ?Sized> Pointer<T> {
}
}

/// Pointer value equality.
impl<T: ?Sized> PartialEq for Pointer<T> {
fn eq(&self, other: &Self) -> bool {
std::ptr::addr_eq(self.ptr, other.ptr)
}
}

/// Allow printing of the pointer.
impl<T: ?Sized> Debug for Pointer<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down

0 comments on commit 9256918

Please sign in to comment.