Skip to content

Commit 1635903

Browse files
committed
Improve/fix scoped UserData drop
1 parent 2b2df70 commit 1635903

File tree

3 files changed

+149
-105
lines changed

3 files changed

+149
-105
lines changed

src/scope.rs

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,14 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
186186
let state = u.lua.state;
187187
assert_stack(state, 2);
188188
u.lua.push_ref(&u);
189+
190+
// Clear uservalue
191+
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
192+
ffi::lua_pushnil(state);
193+
#[cfg(any(feature = "lua51", feature = "luajit"))]
194+
ffi::lua_newtable(state);
195+
ffi::lua_setuservalue(state, -2);
196+
189197
// We know the destructor has not run yet because we hold a reference to the
190198
// userdata.
191199
vec![Box::new(take_userdata::<UserDataCell<T>>(state))]
@@ -244,28 +252,28 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
244252
let check_ud_type = move |lua: &'callback Lua, value| {
245253
if let Some(Value::UserData(ud)) = value {
246254
unsafe {
247-
assert_stack(lua.state, 1);
255+
let _sg = StackGuard::new(lua.state);
256+
assert_stack(lua.state, 3);
248257
lua.push_ref(&ud.0);
249-
ffi::lua_getuservalue(lua.state, -1);
250-
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
251-
{
252-
ffi::lua_rawgeti(lua.state, -1, 1);
253-
ffi::lua_remove(lua.state, -2);
258+
if ffi::lua_getmetatable(lua.state, -1) == 0 {
259+
return Err(Error::UserDataTypeMismatch);
260+
}
261+
ffi::lua_pushstring(lua.state, cstr!("__mlua"));
262+
if ffi::lua_rawget(lua.state, -2) == ffi::LUA_TLIGHTUSERDATA {
263+
let ud_ptr = ffi::lua_touserdata(lua.state, -1);
264+
if ud_ptr == check_data.as_ptr() as *mut c_void {
265+
return Ok(());
266+
}
254267
}
255-
return ffi::lua_touserdata(lua.state, -1)
256-
== check_data.as_ptr() as *mut c_void;
257268
}
258-
}
259-
260-
false
269+
};
270+
Err(Error::UserDataTypeMismatch)
261271
};
262272

263273
match method {
264274
NonStaticMethod::Method(method) => {
265275
let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
266-
if !check_ud_type(lua, args.pop_front()) {
267-
return Err(Error::UserDataTypeMismatch);
268-
}
276+
check_ud_type(lua, args.pop_front())?;
269277
let data = data
270278
.try_borrow()
271279
.map(|cell| Ref::map(cell, AsRef::as_ref))
@@ -277,9 +285,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
277285
NonStaticMethod::MethodMut(method) => {
278286
let method = RefCell::new(method);
279287
let f = Box::new(move |lua, mut args: MultiValue<'callback>| {
280-
if !check_ud_type(lua, args.pop_front()) {
281-
return Err(Error::UserDataTypeMismatch);
282-
}
288+
check_ud_type(lua, args.pop_front())?;
283289
let mut method = method
284290
.try_borrow_mut()
285291
.map_err(|_| Error::RecursiveMutCallback)?;
@@ -314,24 +320,19 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
314320
unsafe {
315321
let lua = self.lua;
316322
let _sg = StackGuard::new(lua.state);
317-
assert_stack(lua.state, 6);
323+
assert_stack(lua.state, 13);
318324

319325
// We need to wrap dummy userdata because their memory can be accessed by serializer
320326
push_userdata(lua.state, UserDataCell::new(UserDataWrapped::new(())))?;
321-
#[cfg(any(feature = "lua54", feature = "lua53"))]
322-
ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void);
323-
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
324-
protect_lua_closure(lua.state, 0, 1, |state| {
325-
// Lua 5.2/5.1 allows to store only table. Then we will wrap the value.
326-
ffi::lua_createtable(state, 1, 0);
327-
ffi::lua_pushlightuserdata(state, data.as_ptr() as *mut c_void);
328-
ffi::lua_rawseti(state, -2, 1);
329-
})?;
330-
ffi::lua_setuservalue(lua.state, -2);
331327

332328
// Prepare metatable, add meta methods first and then meta fields
333-
protect_lua_closure(lua.state, 0, 1, move |state| {
329+
protect_lua_closure(lua.state, 0, 1, |state| {
334330
ffi::lua_newtable(state);
331+
332+
// Add internal metamethod to store reference to the data
333+
ffi::lua_pushstring(state, cstr!("__mlua"));
334+
ffi::lua_pushlightuserdata(lua.state, data.as_ptr() as *mut c_void);
335+
ffi::lua_rawset(state, -3);
335336
})?;
336337
for (k, m) in ud_methods.meta_methods {
337338
push_string(lua.state, k.validate()?.name())?;
@@ -415,19 +416,31 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
415416

416417
let mt_id = ffi::lua_topointer(lua.state, -1);
417418
ffi::lua_setmetatable(lua.state, -2);
418-
419419
let ud = AnyUserData(lua.pop_ref());
420420
lua.register_userdata_metatable(mt_id as isize);
421+
421422
self.destructors.borrow_mut().push((ud.0.clone(), |ud| {
423+
// We know the destructor has not run yet because we hold a reference to the userdata.
422424
let state = ud.lua.state;
423425
assert_stack(state, 2);
424426
ud.lua.push_ref(&ud);
427+
428+
// Deregister metatable
425429
ffi::lua_getmetatable(state, -1);
426430
let mt_id = ffi::lua_topointer(state, -1);
427431
ffi::lua_pop(state, 1);
428432
ud.lua.deregister_userdata_metatable(mt_id as isize);
433+
434+
// Clear uservalue
435+
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))]
436+
ffi::lua_pushnil(state);
437+
#[cfg(any(feature = "lua51", feature = "luajit"))]
438+
ffi::lua_newtable(state);
439+
ffi::lua_setuservalue(state, -2);
440+
429441
vec![Box::new(take_userdata::<UserDataCell<()>>(state))]
430442
}));
443+
431444
Ok(ud)
432445
}
433446
}

src/userdata.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ impl MetaMethod {
186186
MetaMethod::Custom(name) if name == "__metatable" => {
187187
Err(Error::MetaMethodRestricted(name))
188188
}
189+
MetaMethod::Custom(name) if name == "__mlua" => Err(Error::MetaMethodRestricted(name)),
189190
_ => Ok(self),
190191
}
191192
}

tests/scope.rs

Lines changed: 105 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::cell::Cell;
22
use std::rc::Rc;
3+
use std::sync::Arc;
34

45
use mlua::{
56
AnyUserData, Error, Function, Lua, MetaMethod, Result, String, UserData, UserDataFields,
@@ -26,84 +27,13 @@ fn scope_func() -> Result<()> {
2627
assert_eq!(Rc::strong_count(&rc), 1);
2728

2829
match lua.globals().get::<_, Function>("bad")?.call::<_, ()>(()) {
29-
Err(Error::CallbackError { .. }) => {}
30-
r => panic!("improper return for destructed function: {:?}", r),
31-
};
32-
33-
Ok(())
34-
}
35-
36-
#[test]
37-
fn scope_drop() -> Result<()> {
38-
let lua = Lua::new();
39-
40-
struct MyUserdata(Rc<()>);
41-
impl UserData for MyUserdata {
42-
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
43-
methods.add_method("method", |_, _, ()| Ok(()));
44-
}
45-
}
46-
47-
let rc = Rc::new(());
48-
49-
lua.scope(|scope| {
50-
lua.globals()
51-
.set("static_ud", scope.create_userdata(MyUserdata(rc.clone()))?)?;
52-
assert_eq!(Rc::strong_count(&rc), 2);
53-
Ok(())
54-
})?;
55-
assert_eq!(Rc::strong_count(&rc), 1);
56-
57-
match lua.load("static_ud:method()").exec() {
58-
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
59-
Error::CallbackDestructed => {}
60-
e => panic!("expected CallbackDestructed, got {:?}", e),
61-
},
62-
r => panic!("improper return for destructed userdata: {:?}", r),
63-
};
64-
65-
let static_ud = lua.globals().get::<_, AnyUserData>("static_ud")?;
66-
match static_ud.borrow::<MyUserdata>() {
67-
Ok(_) => panic!("borrowed destructed userdata"),
68-
Err(Error::UserDataDestructed) => {}
69-
Err(e) => panic!("expected UserDataDestructed, got {:?}", e),
70-
}
71-
72-
// Check non-static UserData drop
73-
struct MyUserDataRef<'a>(&'a Cell<i64>);
74-
75-
impl<'a> UserData for MyUserDataRef<'a> {
76-
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
77-
methods.add_method("inc", |_, data, ()| {
78-
data.0.set(data.0.get() + 1);
79-
Ok(())
80-
});
81-
}
82-
}
83-
84-
let i = Cell::new(1);
85-
lua.scope(|scope| {
86-
lua.globals().set(
87-
"nonstatic_ud",
88-
scope.create_nonstatic_userdata(MyUserDataRef(&i))?,
89-
)
90-
})?;
91-
92-
match lua.load("nonstatic_ud:inc(1)").exec() {
93-
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
30+
Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() {
9431
Error::CallbackDestructed => {}
95-
e => panic!("expected CallbackDestructed, got {:?}", e),
32+
ref err => panic!("wrong error type {:?}", err),
9633
},
97-
r => panic!("improper return for destructed userdata: {:?}", r),
34+
r => panic!("improper return for destructed function: {:?}", r),
9835
};
9936

100-
let nonstatic_ud = lua.globals().get::<_, AnyUserData>("nonstatic_ud")?;
101-
match nonstatic_ud.borrow::<MyUserDataRef>() {
102-
Ok(_) => panic!("borrowed destructed userdata"),
103-
Err(Error::UserDataDestructed) => {}
104-
Err(e) => panic!("expected UserDataDestructed, got {:?}", e),
105-
}
106-
10737
Ok(())
10838
}
10939

@@ -126,7 +56,7 @@ fn scope_capture() -> Result<()> {
12656
}
12757

12858
#[test]
129-
fn outer_lua_access() -> Result<()> {
59+
fn scope_outer_lua_access() -> Result<()> {
13060
let lua = Lua::new();
13161

13262
let table = lua.create_table()?;
@@ -309,3 +239,103 @@ fn scope_userdata_mismatch() -> Result<()> {
309239

310240
Ok(())
311241
}
242+
243+
#[test]
244+
fn scope_userdata_drop() -> Result<()> {
245+
let lua = Lua::new();
246+
247+
struct MyUserData(Rc<()>);
248+
249+
impl UserData for MyUserData {
250+
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
251+
methods.add_method("method", |_, _, ()| Ok(()));
252+
}
253+
}
254+
255+
struct MyUserDataArc(Arc<()>);
256+
257+
impl UserData for MyUserDataArc {}
258+
259+
let rc = Rc::new(());
260+
let arc = Arc::new(());
261+
lua.scope(|scope| {
262+
let ud = scope.create_userdata(MyUserData(rc.clone()))?;
263+
ud.set_user_value(MyUserDataArc(arc.clone()))?;
264+
lua.globals().set("ud", ud)?;
265+
assert_eq!(Rc::strong_count(&rc), 2);
266+
assert_eq!(Arc::strong_count(&arc), 2);
267+
Ok(())
268+
})?;
269+
270+
lua.gc_collect()?;
271+
assert_eq!(Rc::strong_count(&rc), 1);
272+
assert_eq!(Arc::strong_count(&arc), 1);
273+
274+
match lua.load("ud:method()").exec() {
275+
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
276+
Error::CallbackDestructed => {}
277+
err => panic!("expected CallbackDestructed, got {:?}", err),
278+
},
279+
r => panic!("improper return for destructed userdata: {:?}", r),
280+
};
281+
282+
let ud = lua.globals().get::<_, AnyUserData>("ud")?;
283+
match ud.borrow::<MyUserData>() {
284+
Ok(_) => panic!("succesfull borrow for destructed userdata"),
285+
Err(Error::UserDataDestructed) => {}
286+
Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err),
287+
}
288+
289+
Ok(())
290+
}
291+
292+
#[test]
293+
fn scope_nonstatic_userdata_drop() -> Result<()> {
294+
let lua = Lua::new();
295+
296+
struct MyUserData<'a>(&'a Cell<i64>);
297+
298+
impl<'a> UserData for MyUserData<'a> {
299+
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
300+
methods.add_method("inc", |_, data, ()| {
301+
data.0.set(data.0.get() + 1);
302+
Ok(())
303+
});
304+
}
305+
}
306+
307+
struct MyUserDataArc(Arc<()>);
308+
309+
impl UserData for MyUserDataArc {}
310+
311+
let i = Cell::new(1);
312+
let arc = Arc::new(());
313+
lua.scope(|scope| {
314+
let ud = scope.create_nonstatic_userdata(MyUserData(&i))?;
315+
ud.set_user_value(MyUserDataArc(arc.clone()))?;
316+
lua.globals().set("ud", ud)?;
317+
lua.load("ud:inc()").exec()?;
318+
assert_eq!(Arc::strong_count(&arc), 2);
319+
Ok(())
320+
})?;
321+
322+
lua.gc_collect()?;
323+
assert_eq!(Arc::strong_count(&arc), 1);
324+
325+
match lua.load("ud:inc()").exec() {
326+
Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() {
327+
Error::CallbackDestructed => {}
328+
err => panic!("expected CallbackDestructed, got {:?}", err),
329+
},
330+
r => panic!("improper return for destructed userdata: {:?}", r),
331+
};
332+
333+
let ud = lua.globals().get::<_, AnyUserData>("ud")?;
334+
match ud.borrow::<MyUserData>() {
335+
Ok(_) => panic!("succesfull borrow for destructed userdata"),
336+
Err(Error::UserDataDestructed) => {}
337+
Err(err) => panic!("improper borrow error for destructed userdata: {:?}", err),
338+
}
339+
340+
Ok(())
341+
}

0 commit comments

Comments
 (0)