diff --git a/core/common/src/abstractions/callbacks.rs b/core/common/src/abstractions/callbacks.rs index bebfccf6..cdee623b 100644 --- a/core/common/src/abstractions/callbacks.rs +++ b/core/common/src/abstractions/callbacks.rs @@ -1,5 +1,5 @@ use std::{ - fmt::{Debug, Formatter}, + fmt::{Debug, Display, Formatter}, marker::PhantomData, panic::RefUnwindSafe, sync::Arc, @@ -14,6 +14,19 @@ pub enum CallbackError { InvalidArgument, } +impl Display for CallbackError { + fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result { + match self { + CallbackError::ExecutionError(message) => { + write!(formatter, "Execution error: {}", message) + } + CallbackError::InvalidArgument => { + write!(formatter, "Invalid argument") + } + } + } +} + /// A boxed callable function. /// /// This is a wrapper around a boxed function that can be called with a list of diff --git a/core/common/src/lua.rs b/core/common/src/lua.rs index e49c7b9d..58da7e24 100644 --- a/core/common/src/lua.rs +++ b/core/common/src/lua.rs @@ -6,8 +6,8 @@ pub use mlua::prelude::*; use crate::{ - Callable, Callback, CallbackError, Color, Color32, FromVariant, Pointer, Quat, ToStringName, ToVariant, - ToVirtualPath, Variant, Vec2, Vec3, Vec4, + Callable, Callback, CallbackError, Color, Color32, FromVariant, Pointer, Quat, ToVariant, ToVirtualPath, Variant, + Vec2, Vec3, Vec4, }; /// A Lua scripting engine. @@ -31,12 +31,12 @@ impl LuaScriptEngine { // configure common globals let globals = engine.globals(); - globals.set_function("vec2", Vec2::new)?; - globals.set_function("vec3", Vec3::new)?; - globals.set_function("vec4", Vec4::new)?; - globals.set_function("quat", Quat::from_xyzw)?; - globals.set_function("rgb", Color::rgb)?; - globals.set_function("rgba", Color::rgba)?; + globals.set_variant_function("vec2", Vec2::new)?; + globals.set_variant_function("vec3", Vec3::new)?; + globals.set_variant_function("vec4", Vec4::new)?; + globals.set_variant_function("quat", Quat::from_xyzw)?; + globals.set_variant_function("rgb", Color::rgb)?; + globals.set_variant_function("rgba", Color::rgba)?; } Ok(engine) @@ -67,61 +67,8 @@ impl LuaScriptEngine { } /// Gets the global table from the Lua state. - pub fn globals(&self) -> LuaScriptTable { - LuaScriptTable { - lua: &self.lua, - table: self.lua.globals(), - } - } -} - -/// A wrapper over a [`LuaTable`] for simplified access. -pub struct LuaScriptTable<'lua> { - lua: &'lua Lua, - table: LuaTable<'lua>, -} - -impl<'lua> LuaScriptTable<'lua> { - /// Gets a value from the table. - pub fn get>(&self, name: &str) -> LuaResult { - self.table.get(name) - } - - /// Sets a value in the table. - pub fn set>(&self, name: &str, value: R) -> LuaResult<()> { - self.table.set(name, value) - } - - /// Gets a sub-table from the table. - pub fn get_table(&self, name: &str) -> LuaResult { - Ok(LuaScriptTable { - lua: self.lua, - table: self.table.get(name)?, - }) - } - - /// Sets a function in the table. - pub fn set_function(&self, name: &str, callback: impl Callback + 'static) -> LuaResult<()> { - // build a closure that can be called from Lua - let function_name = name.to_string_name(); // pool string names - - let body = move |lua, args: LuaMultiValue| { - let args = args - .into_iter() - .map(|value| Variant::from_lua(value, lua)) - .collect::>>()?; - - let result = callback.call(&args).map_err(|error| { - // make it clear which function caused the error - LuaError::RuntimeError(format!("An error occurred calling {}, {:?}", &function_name, error)) - })?; - - Ok(result.into_lua(lua)?) - }; - - self.table.set(name, self.lua.create_function(body)?)?; - - Ok(()) + pub fn globals(&self) -> LuaTable { + self.lua.globals() } } @@ -152,7 +99,6 @@ impl<'lua> IntoLua<'lua> for Variant { Variant::Color32(value) => LuaColor32(value).into_lua(lua)?, Variant::Callable(callable) => { // create a Lua function that calls the callable - // TODO: clean this up? let function = lua.create_function(move |lua, args: LuaMultiValue| { let args = args .into_iter() @@ -161,7 +107,7 @@ impl<'lua> IntoLua<'lua> for Variant { let result = callable .call(&args) - .map_err(|error| LuaError::RuntimeError(format!("An error occurred calling a function, {:?}", error)))?; + .map_err(|error| LuaError::RuntimeError(error.to_string()))?; Ok(result.into_lua(lua)?) })?; @@ -188,7 +134,7 @@ impl<'lua> FromLua<'lua> for Variant { let function = function.into_owned(); let callable = Callable::new(move |args| { function - .call(args) + .call(VariantArray(args.to_vec())) .map_err(|error| CallbackError::ExecutionError(error.to_string())) }); @@ -209,22 +155,47 @@ impl<'lua> FromLua<'lua> for Variant { } } +/// Wraps many [`Variant`]s to be passed as arguments to a Lua function. +/// +/// This is necessary because Lua does not support variadic arguments. +struct VariantArray(Vec); + +impl<'lua> IntoLuaMulti<'lua> for VariantArray { + fn into_lua_multi(self, lua: &'lua Lua) -> LuaResult> { + self.0.into_iter().map(|value| value.into_lua(lua)).collect() + } +} + /// Extension methods for [`LuaTable`] to work with [`Variant`]s. pub trait VariantTableExt<'lua> { - fn get_as(&self, key: impl IntoLua<'lua>) -> LuaResult; - fn set_as(&self, key: impl IntoLua<'lua>, value: R) -> LuaResult<()>; + fn get_variant(&self, key: impl IntoLua<'lua>) -> LuaResult; + fn set_variant(&self, key: impl IntoLua<'lua>, value: R) -> LuaResult<()>; + + /// Calls a function in the table. + fn call_variant_function(&self, key: impl IntoLua<'lua>, args: &[Variant]) -> LuaResult { + let callable: Callable = self.get_variant(key)?; + + callable.call(args).map_err(|it| LuaError::RuntimeError(it.to_string())) + } + + /// Sets a function in the table. + fn set_variant_function(&self, key: impl IntoLua<'lua>, callback: impl Callback + 'static) -> LuaResult<()> { + let callable = Callable::from_callback(callback); + + self.set_variant(key, callable) + } } impl<'lua> VariantTableExt<'lua> for LuaTable<'lua> { #[inline] - fn get_as(&self, key: impl IntoLua<'lua>) -> LuaResult { + fn get_variant(&self, key: impl IntoLua<'lua>) -> LuaResult { let variant = self.get(key); variant.and_then(|value| R::from_variant(value).map_err(|_| LuaError::UserDataTypeMismatch)) } #[inline] - fn set_as(&self, key: impl IntoLua<'lua>, value: R) -> LuaResult<()> { + fn set_variant(&self, key: impl IntoLua<'lua>, value: R) -> LuaResult<()> { let variant = value.to_variant(); self.set(key, variant) @@ -278,8 +249,8 @@ impl<'lua> FromLua<'lua> for LuaVec2 { fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult { match value { LuaValue::Table(value) => { - let x = value.get_as("x")?; - let y = value.get_as("y")?; + let x = value.get_variant("x")?; + let y = value.get_variant("y")?; Ok(LuaVec2(Vec2::new(x, y))) } @@ -323,9 +294,9 @@ impl<'lua> FromLua<'lua> for LuaVec3 { fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult { match value { LuaValue::Table(value) => { - let x = value.get_as("x")?; - let y = value.get_as("y")?; - let z = value.get_as("z")?; + let x = value.get_variant("x")?; + let y = value.get_variant("y")?; + let z = value.get_variant("z")?; Ok(LuaVec3(Vec3::new(x, y, z))) } @@ -370,10 +341,10 @@ impl<'lua> FromLua<'lua> for LuaVec4 { fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult { match value { LuaValue::Table(value) => { - let x = value.get_as("x")?; - let y = value.get_as("y")?; - let z = value.get_as("z")?; - let w = value.get_as("w")?; + let x = value.get_variant("x")?; + let y = value.get_variant("y")?; + let z = value.get_variant("z")?; + let w = value.get_variant("w")?; Ok(LuaVec4(Vec4::new(x, y, z, w))) } @@ -418,10 +389,10 @@ impl<'lua> FromLua<'lua> for LuaQuat { fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult { match value { LuaValue::Table(value) => { - let x = value.get_as("x")?; - let y = value.get_as("y")?; - let z = value.get_as("z")?; - let w = value.get_as("w")?; + let x = value.get_variant("x")?; + let y = value.get_variant("y")?; + let z = value.get_variant("z")?; + let w = value.get_variant("w")?; Ok(LuaQuat(Quat::from_xyzw(x, y, z, w))) } @@ -463,10 +434,10 @@ impl<'lua> FromLua<'lua> for LuaColor { fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult { match value { LuaValue::Table(value) => { - let r = value.get_as("r")?; - let g = value.get_as("g")?; - let b = value.get_as("b")?; - let a = value.get_as("a").unwrap_or(1.0); + let r = value.get_variant("r")?; + let g = value.get_variant("g")?; + let b = value.get_variant("b")?; + let a = value.get_variant("a").unwrap_or(1.0); Ok(LuaColor(Color::rgba(r, g, b, a))) } @@ -506,10 +477,10 @@ impl<'lua> FromLua<'lua> for LuaColor32 { fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult { match value { LuaValue::Table(value) => { - let r = value.get_as("r")?; - let g = value.get_as("g")?; - let b = value.get_as("b")?; - let a = value.get_as("a").unwrap_or(255); + let r = value.get_variant("r")?; + let g = value.get_variant("g")?; + let b = value.get_variant("b")?; + let a = value.get_variant("a").unwrap_or(255); Ok(LuaColor32(Color32::rgba(r, g, b, a))) } @@ -549,7 +520,7 @@ mod tests { let lua = LuaScriptEngine::new().unwrap(); let globals = lua.globals(); - globals.set_function("vec2", Vec2::new).unwrap(); + globals.set_variant_function("vec2", Vec2::new).unwrap(); let script = r#" local a = vec2(1, 2) @@ -573,7 +544,7 @@ mod tests { let lua = LuaScriptEngine::new().unwrap(); let globals = lua.globals(); - globals.set_function("vec3", Vec3::new).unwrap(); + globals.set_variant_function("vec3", Vec3::new).unwrap(); let script = r#" local a = vec3(1, 2, 3) @@ -597,7 +568,7 @@ mod tests { let lua = LuaScriptEngine::new().unwrap(); let globals = lua.globals(); - globals.set_function("vec4", Vec4::new).unwrap(); + globals.set_variant_function("vec4", Vec4::new).unwrap(); let script = r#" local a = vec4(1, 2, 3, 4) @@ -621,7 +592,7 @@ mod tests { let lua = LuaScriptEngine::new().unwrap(); let globals = lua.globals(); - globals.set_function("quat", Quat::from_xyzw).unwrap(); + globals.set_variant_function("quat", Quat::from_xyzw).unwrap(); let script = r#" local a = quat(1, 2, 3, 4) @@ -645,7 +616,7 @@ mod tests { let lua = LuaScriptEngine::new().unwrap(); let globals = lua.globals(); - globals.set_function("color", Color::rgba).unwrap(); + globals.set_variant_function("color", Color::rgba).unwrap(); let script = r#" local a = color(1, 2, 3, 4) @@ -667,7 +638,7 @@ mod tests { let lua = LuaScriptEngine::new().unwrap(); let globals = lua.globals(); - globals.set_function("color32", Color32::rgba).unwrap(); + globals.set_variant_function("color32", Color32::rgba).unwrap(); let script = r#" local a = color32(1, 2, 3, 4) @@ -679,4 +650,40 @@ mod tests { lua.run(script).unwrap(); } + + #[test] + fn test_basic_call_from_lua() { + let lua = LuaScriptEngine::new().unwrap(); + let globals = lua.globals(); + + globals.set_variant_function("add", |a: i32, b: i32| a + b).unwrap(); + + let script = r#" + local a = add(1, 2) + + assert(a == 3) + "#; + + lua.run(script).unwrap(); + } + + #[test] + fn test_basic_call_from_rust() { + let lua = LuaScriptEngine::new().unwrap(); + + let script = r#" + function add(a, b) + return a + b + end + "#; + + lua.run(script).unwrap(); + + let globals = lua.globals(); + + let args = [Variant::I64(3), Variant::I64(4)]; + let result = globals.call_variant_function("add", &args).unwrap(); + + assert_eq!(result, Variant::I64(7)); + } }