Skip to content

Commit

Permalink
Some more work on Lua interop
Browse files Browse the repository at this point in the history
  • Loading branch information
mattkleiny committed Aug 5, 2024
1 parent e88c860 commit 04b3c59
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 98 deletions.
15 changes: 14 additions & 1 deletion core/common/src/abstractions/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
fmt::{Debug, Formatter},
fmt::{Debug, Display, Formatter},
marker::PhantomData,
panic::RefUnwindSafe,
sync::Arc,
Expand All @@ -14,6 +14,19 @@ pub enum CallbackError {
InvalidArgument,
}

impl Display for CallbackError {
fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
match self {
CallbackError::ExecutionError(message) => {
write!(formatter, "Execution error: {}", message)
}
CallbackError::InvalidArgument => {
write!(formatter, "Invalid argument")
}
}
}
}

/// A boxed callable function.
///
/// This is a wrapper around a boxed function that can be called with a list of
Expand Down
201 changes: 104 additions & 97 deletions core/common/src/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
pub use mlua::prelude::*;

use crate::{
Callable, Callback, CallbackError, Color, Color32, FromVariant, Pointer, Quat, ToStringName, ToVariant,
ToVirtualPath, Variant, Vec2, Vec3, Vec4,
Callable, Callback, CallbackError, Color, Color32, FromVariant, Pointer, Quat, ToVariant, ToVirtualPath, Variant,
Vec2, Vec3, Vec4,
};

/// A Lua scripting engine.
Expand All @@ -31,12 +31,12 @@ impl LuaScriptEngine {
// configure common globals
let globals = engine.globals();

globals.set_function("vec2", Vec2::new)?;
globals.set_function("vec3", Vec3::new)?;
globals.set_function("vec4", Vec4::new)?;
globals.set_function("quat", Quat::from_xyzw)?;
globals.set_function("rgb", Color::rgb)?;
globals.set_function("rgba", Color::rgba)?;
globals.set_variant_function("vec2", Vec2::new)?;
globals.set_variant_function("vec3", Vec3::new)?;
globals.set_variant_function("vec4", Vec4::new)?;
globals.set_variant_function("quat", Quat::from_xyzw)?;
globals.set_variant_function("rgb", Color::rgb)?;
globals.set_variant_function("rgba", Color::rgba)?;
}

Ok(engine)
Expand Down Expand Up @@ -67,61 +67,8 @@ impl LuaScriptEngine {
}

/// Gets the global table from the Lua state.
pub fn globals(&self) -> LuaScriptTable {
LuaScriptTable {
lua: &self.lua,
table: self.lua.globals(),
}
}
}

/// A wrapper over a [`LuaTable`] for simplified access.
pub struct LuaScriptTable<'lua> {
lua: &'lua Lua,
table: LuaTable<'lua>,
}

impl<'lua> LuaScriptTable<'lua> {
/// Gets a value from the table.
pub fn get<R: FromLua<'lua>>(&self, name: &str) -> LuaResult<R> {
self.table.get(name)
}

/// Sets a value in the table.
pub fn set<R: IntoLua<'lua>>(&self, name: &str, value: R) -> LuaResult<()> {
self.table.set(name, value)
}

/// Gets a sub-table from the table.
pub fn get_table(&self, name: &str) -> LuaResult<Self> {
Ok(LuaScriptTable {
lua: self.lua,
table: self.table.get(name)?,
})
}

/// Sets a function in the table.
pub fn set_function<R>(&self, name: &str, callback: impl Callback<R> + 'static) -> LuaResult<()> {
// build a closure that can be called from Lua
let function_name = name.to_string_name(); // pool string names

let body = move |lua, args: LuaMultiValue| {
let args = args
.into_iter()
.map(|value| Variant::from_lua(value, lua))
.collect::<LuaResult<Vec<_>>>()?;

let result = callback.call(&args).map_err(|error| {
// make it clear which function caused the error
LuaError::RuntimeError(format!("An error occurred calling {}, {:?}", &function_name, error))
})?;

Ok(result.into_lua(lua)?)
};

self.table.set(name, self.lua.create_function(body)?)?;

Ok(())
pub fn globals(&self) -> LuaTable {
self.lua.globals()
}
}

Expand Down Expand Up @@ -152,7 +99,6 @@ impl<'lua> IntoLua<'lua> for Variant {
Variant::Color32(value) => LuaColor32(value).into_lua(lua)?,
Variant::Callable(callable) => {
// create a Lua function that calls the callable
// TODO: clean this up?
let function = lua.create_function(move |lua, args: LuaMultiValue| {
let args = args
.into_iter()
Expand All @@ -161,7 +107,7 @@ impl<'lua> IntoLua<'lua> for Variant {

let result = callable
.call(&args)
.map_err(|error| LuaError::RuntimeError(format!("An error occurred calling a function, {:?}", error)))?;
.map_err(|error| LuaError::RuntimeError(error.to_string()))?;

Ok(result.into_lua(lua)?)
})?;
Expand All @@ -188,7 +134,7 @@ impl<'lua> FromLua<'lua> for Variant {
let function = function.into_owned();
let callable = Callable::new(move |args| {
function
.call(args)
.call(VariantArray(args.to_vec()))
.map_err(|error| CallbackError::ExecutionError(error.to_string()))
});

Expand All @@ -209,22 +155,47 @@ impl<'lua> FromLua<'lua> for Variant {
}
}

/// Wraps many [`Variant`]s to be passed as arguments to a Lua function.
///
/// This is necessary because Lua does not support variadic arguments.
struct VariantArray(Vec<Variant>);

impl<'lua> IntoLuaMulti<'lua> for VariantArray {
fn into_lua_multi(self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
self.0.into_iter().map(|value| value.into_lua(lua)).collect()
}
}

/// Extension methods for [`LuaTable`] to work with [`Variant`]s.
pub trait VariantTableExt<'lua> {
fn get_as<R: FromVariant>(&self, key: impl IntoLua<'lua>) -> LuaResult<R>;
fn set_as<R: ToVariant>(&self, key: impl IntoLua<'lua>, value: R) -> LuaResult<()>;
fn get_variant<R: FromVariant>(&self, key: impl IntoLua<'lua>) -> LuaResult<R>;
fn set_variant<R: ToVariant>(&self, key: impl IntoLua<'lua>, value: R) -> LuaResult<()>;

/// Calls a function in the table.
fn call_variant_function(&self, key: impl IntoLua<'lua>, args: &[Variant]) -> LuaResult<Variant> {
let callable: Callable = self.get_variant(key)?;

callable.call(args).map_err(|it| LuaError::RuntimeError(it.to_string()))
}

/// Sets a function in the table.
fn set_variant_function<R>(&self, key: impl IntoLua<'lua>, callback: impl Callback<R> + 'static) -> LuaResult<()> {
let callable = Callable::from_callback(callback);

self.set_variant(key, callable)
}
}

impl<'lua> VariantTableExt<'lua> for LuaTable<'lua> {
#[inline]
fn get_as<R: FromVariant>(&self, key: impl IntoLua<'lua>) -> LuaResult<R> {
fn get_variant<R: FromVariant>(&self, key: impl IntoLua<'lua>) -> LuaResult<R> {
let variant = self.get(key);

variant.and_then(|value| R::from_variant(value).map_err(|_| LuaError::UserDataTypeMismatch))
}

#[inline]
fn set_as<R: ToVariant>(&self, key: impl IntoLua<'lua>, value: R) -> LuaResult<()> {
fn set_variant<R: ToVariant>(&self, key: impl IntoLua<'lua>, value: R) -> LuaResult<()> {
let variant = value.to_variant();

self.set(key, variant)
Expand Down Expand Up @@ -278,8 +249,8 @@ impl<'lua> FromLua<'lua> for LuaVec2 {
fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Table(value) => {
let x = value.get_as("x")?;
let y = value.get_as("y")?;
let x = value.get_variant("x")?;
let y = value.get_variant("y")?;

Ok(LuaVec2(Vec2::new(x, y)))
}
Expand Down Expand Up @@ -323,9 +294,9 @@ impl<'lua> FromLua<'lua> for LuaVec3 {
fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Table(value) => {
let x = value.get_as("x")?;
let y = value.get_as("y")?;
let z = value.get_as("z")?;
let x = value.get_variant("x")?;
let y = value.get_variant("y")?;
let z = value.get_variant("z")?;

Ok(LuaVec3(Vec3::new(x, y, z)))
}
Expand Down Expand Up @@ -370,10 +341,10 @@ impl<'lua> FromLua<'lua> for LuaVec4 {
fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Table(value) => {
let x = value.get_as("x")?;
let y = value.get_as("y")?;
let z = value.get_as("z")?;
let w = value.get_as("w")?;
let x = value.get_variant("x")?;
let y = value.get_variant("y")?;
let z = value.get_variant("z")?;
let w = value.get_variant("w")?;

Ok(LuaVec4(Vec4::new(x, y, z, w)))
}
Expand Down Expand Up @@ -418,10 +389,10 @@ impl<'lua> FromLua<'lua> for LuaQuat {
fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Table(value) => {
let x = value.get_as("x")?;
let y = value.get_as("y")?;
let z = value.get_as("z")?;
let w = value.get_as("w")?;
let x = value.get_variant("x")?;
let y = value.get_variant("y")?;
let z = value.get_variant("z")?;
let w = value.get_variant("w")?;

Ok(LuaQuat(Quat::from_xyzw(x, y, z, w)))
}
Expand Down Expand Up @@ -463,10 +434,10 @@ impl<'lua> FromLua<'lua> for LuaColor {
fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Table(value) => {
let r = value.get_as("r")?;
let g = value.get_as("g")?;
let b = value.get_as("b")?;
let a = value.get_as("a").unwrap_or(1.0);
let r = value.get_variant("r")?;
let g = value.get_variant("g")?;
let b = value.get_variant("b")?;
let a = value.get_variant("a").unwrap_or(1.0);

Ok(LuaColor(Color::rgba(r, g, b, a)))
}
Expand Down Expand Up @@ -506,10 +477,10 @@ impl<'lua> FromLua<'lua> for LuaColor32 {
fn from_lua(value: LuaValue<'lua>, _lua: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Table(value) => {
let r = value.get_as("r")?;
let g = value.get_as("g")?;
let b = value.get_as("b")?;
let a = value.get_as("a").unwrap_or(255);
let r = value.get_variant("r")?;
let g = value.get_variant("g")?;
let b = value.get_variant("b")?;
let a = value.get_variant("a").unwrap_or(255);

Ok(LuaColor32(Color32::rgba(r, g, b, a)))
}
Expand Down Expand Up @@ -549,7 +520,7 @@ mod tests {
let lua = LuaScriptEngine::new().unwrap();
let globals = lua.globals();

globals.set_function("vec2", Vec2::new).unwrap();
globals.set_variant_function("vec2", Vec2::new).unwrap();

let script = r#"
local a = vec2(1, 2)
Expand All @@ -573,7 +544,7 @@ mod tests {
let lua = LuaScriptEngine::new().unwrap();
let globals = lua.globals();

globals.set_function("vec3", Vec3::new).unwrap();
globals.set_variant_function("vec3", Vec3::new).unwrap();

let script = r#"
local a = vec3(1, 2, 3)
Expand All @@ -597,7 +568,7 @@ mod tests {
let lua = LuaScriptEngine::new().unwrap();
let globals = lua.globals();

globals.set_function("vec4", Vec4::new).unwrap();
globals.set_variant_function("vec4", Vec4::new).unwrap();

let script = r#"
local a = vec4(1, 2, 3, 4)
Expand All @@ -621,7 +592,7 @@ mod tests {
let lua = LuaScriptEngine::new().unwrap();
let globals = lua.globals();

globals.set_function("quat", Quat::from_xyzw).unwrap();
globals.set_variant_function("quat", Quat::from_xyzw).unwrap();

let script = r#"
local a = quat(1, 2, 3, 4)
Expand All @@ -645,7 +616,7 @@ mod tests {
let lua = LuaScriptEngine::new().unwrap();
let globals = lua.globals();

globals.set_function("color", Color::rgba).unwrap();
globals.set_variant_function("color", Color::rgba).unwrap();

let script = r#"
local a = color(1, 2, 3, 4)
Expand All @@ -667,7 +638,7 @@ mod tests {
let lua = LuaScriptEngine::new().unwrap();
let globals = lua.globals();

globals.set_function("color32", Color32::rgba).unwrap();
globals.set_variant_function("color32", Color32::rgba).unwrap();

let script = r#"
local a = color32(1, 2, 3, 4)
Expand All @@ -679,4 +650,40 @@ mod tests {

lua.run(script).unwrap();
}

#[test]
fn test_basic_call_from_lua() {
let lua = LuaScriptEngine::new().unwrap();
let globals = lua.globals();

globals.set_variant_function("add", |a: i32, b: i32| a + b).unwrap();

let script = r#"
local a = add(1, 2)
assert(a == 3)
"#;

lua.run(script).unwrap();
}

#[test]
fn test_basic_call_from_rust() {
let lua = LuaScriptEngine::new().unwrap();

let script = r#"
function add(a, b)
return a + b
end
"#;

lua.run(script).unwrap();

let globals = lua.globals();

let args = [Variant::I64(3), Variant::I64(4)];
let result = globals.call_variant_function("add", &args).unwrap();

assert_eq!(result, Variant::I64(7));
}
}

0 comments on commit 04b3c59

Please sign in to comment.