| 
 | 1 | +use std::result::Result as StdResult;  | 
 | 2 | +use std::sync::Arc;  | 
 | 3 | + | 
 | 4 | +use mlua::{  | 
 | 5 | +    AnyUserData, Error, Function, Integer as LuaInteger, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod,  | 
 | 6 | +    MultiValue, Result, SerializeOptions, Table, UserData, UserDataMethods, UserDataRefMut, Value,  | 
 | 7 | +};  | 
 | 8 | +use ouroboros::self_referencing;  | 
 | 9 | +use serde::{Serialize, Serializer};  | 
 | 10 | + | 
 | 11 | +use crate::bytes::StringOrBytes;  | 
 | 12 | + | 
 | 13 | +/// Represents a native YAML object in Lua.  | 
 | 14 | +#[derive(Clone)]  | 
 | 15 | +pub(crate) struct YamlObject {  | 
 | 16 | +    root: Arc<serde_yaml::Value>,  | 
 | 17 | +    current: *const serde_yaml::Value,  | 
 | 18 | +}  | 
 | 19 | + | 
 | 20 | +impl Serialize for YamlObject {  | 
 | 21 | +    fn serialize<S: Serializer>(&self, serializer: S) -> StdResult<S::Ok, S::Error> {  | 
 | 22 | +        self.current().serialize(serializer)  | 
 | 23 | +    }  | 
 | 24 | +}  | 
 | 25 | + | 
 | 26 | +impl YamlObject {  | 
 | 27 | +    /// Creates a new `YamlObject` from the given YAML value.  | 
 | 28 | +    ///  | 
 | 29 | +    /// SAFETY:  | 
 | 30 | +    /// The caller must ensure that `current` is a value inside `root`.  | 
 | 31 | +    unsafe fn new(root: &Arc<serde_yaml::Value>, current: &serde_yaml::Value) -> Self {  | 
 | 32 | +        let root = root.clone();  | 
 | 33 | +        YamlObject { root, current }  | 
 | 34 | +    }  | 
 | 35 | + | 
 | 36 | +    /// Returns a reference to the current YAML value.  | 
 | 37 | +    #[inline(always)]  | 
 | 38 | +    fn current(&self) -> &serde_yaml::Value {  | 
 | 39 | +        unsafe { &*self.current }  | 
 | 40 | +    }  | 
 | 41 | + | 
 | 42 | +    /// Returns a new `YamlObject` which points to the value at the given key.  | 
 | 43 | +    ///  | 
 | 44 | +    /// This operation is cheap and does not clone the underlying data.  | 
 | 45 | +    fn get(&self, key: Value) -> Option<Self> {  | 
 | 46 | +        let value = match key {  | 
 | 47 | +            Value::Integer(index) if index > 0 => self.current().get(index as usize - 1),  | 
 | 48 | +            Value::String(key) => key.to_str().ok().and_then(|s| self.current().get(&*s)),  | 
 | 49 | +            _ => None,  | 
 | 50 | +        }?;  | 
 | 51 | +        unsafe { Some(Self::new(&self.root, value)) }  | 
 | 52 | +    }  | 
 | 53 | + | 
 | 54 | +    /// Converts this `YamlObject` into a Lua `Value`.  | 
 | 55 | +    fn into_lua(self, lua: &Lua) -> Result<Value> {  | 
 | 56 | +        match self.current() {  | 
 | 57 | +            serde_yaml::Value::Null => Ok(Value::NULL),  | 
 | 58 | +            serde_yaml::Value::Bool(b) => Ok(Value::Boolean(*b)),  | 
 | 59 | +            serde_yaml::Value::Number(n) => {  | 
 | 60 | +                if let Some(n) = n.as_i64() {  | 
 | 61 | +                    Ok(Value::Integer(n as _))  | 
 | 62 | +                } else if let Some(n) = n.as_f64() {  | 
 | 63 | +                    Ok(Value::Number(n))  | 
 | 64 | +                } else {  | 
 | 65 | +                    Err(Error::ToLuaConversionError {  | 
 | 66 | +                        from: "number".to_string(),  | 
 | 67 | +                        to: "integer or float",  | 
 | 68 | +                        message: Some("number is too big to fit in a Lua integer".to_owned()),  | 
 | 69 | +                    })  | 
 | 70 | +                }  | 
 | 71 | +            }  | 
 | 72 | +            serde_yaml::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),  | 
 | 73 | +            value @ serde_yaml::Value::Sequence(_) | value @ serde_yaml::Value::Mapping(_) => {  | 
 | 74 | +                let obj_ud = lua.create_ser_userdata(unsafe { YamlObject::new(&self.root, value) })?;  | 
 | 75 | +                Ok(Value::UserData(obj_ud))  | 
 | 76 | +            }  | 
 | 77 | +            serde_yaml::Value::Tagged(tagged) => {  | 
 | 78 | +                // For tagged values, we'll return the value part and ignore the tag for simplicity  | 
 | 79 | +                let obj = unsafe { YamlObject::new(&self.root, &tagged.value) };  | 
 | 80 | +                obj.into_lua(lua)  | 
 | 81 | +            }  | 
 | 82 | +        }  | 
 | 83 | +    }  | 
 | 84 | + | 
 | 85 | +    fn lua_iterator(&self, lua: &Lua) -> Result<MultiValue> {  | 
 | 86 | +        match self.current() {  | 
 | 87 | +            serde_yaml::Value::Sequence(_) => {  | 
 | 88 | +                let next = Self::lua_array_iterator(lua)?;  | 
 | 89 | +                let iter_ud = AnyUserData::wrap(LuaYamlArrayIter {  | 
 | 90 | +                    value: self.clone(),  | 
 | 91 | +                    next: 1, // index starts at 1  | 
 | 92 | +                });  | 
 | 93 | +                (next, iter_ud).into_lua_multi(lua)  | 
 | 94 | +            }  | 
 | 95 | +            serde_yaml::Value::Mapping(_) => {  | 
 | 96 | +                let next = Self::lua_map_iterator(lua)?;  | 
 | 97 | +                let iter_builder = LuaYamlMapIterBuilder {  | 
 | 98 | +                    value: self.clone(),  | 
 | 99 | +                    iter_builder: |value| value.current().as_mapping().unwrap().iter(),  | 
 | 100 | +                };  | 
 | 101 | +                let iter_ud = AnyUserData::wrap(iter_builder.build());  | 
 | 102 | +                (next, iter_ud).into_lua_multi(lua)  | 
 | 103 | +            }  | 
 | 104 | +            _ => ().into_lua_multi(lua),  | 
 | 105 | +        }  | 
 | 106 | +    }  | 
 | 107 | + | 
 | 108 | +    /// Returns an iterator function for arrays.  | 
 | 109 | +    fn lua_array_iterator(lua: &Lua) -> Result<Function> {  | 
 | 110 | +        if let Ok(Some(f)) = lua.named_registry_value("__yaml_array_iterator") {  | 
 | 111 | +            return Ok(f);  | 
 | 112 | +        }  | 
 | 113 | + | 
 | 114 | +        let f = lua.create_function(|lua, mut it: UserDataRefMut<LuaYamlArrayIter>| {  | 
 | 115 | +            it.next += 1;  | 
 | 116 | +            match it.value.get(Value::Integer(it.next - 1)) {  | 
 | 117 | +                Some(next_value) => (it.next - 1, next_value.into_lua(lua)?).into_lua_multi(lua),  | 
 | 118 | +                None => ().into_lua_multi(lua),  | 
 | 119 | +            }  | 
 | 120 | +        })?;  | 
 | 121 | +        lua.set_named_registry_value("__yaml_array_iterator", &f)?;  | 
 | 122 | +        Ok(f)  | 
 | 123 | +    }  | 
 | 124 | + | 
 | 125 | +    /// Returns an iterator function for objects.  | 
 | 126 | +    fn lua_map_iterator(lua: &Lua) -> Result<Function> {  | 
 | 127 | +        if let Ok(Some(f)) = lua.named_registry_value("__yaml_map_iterator") {  | 
 | 128 | +            return Ok(f);  | 
 | 129 | +        }  | 
 | 130 | + | 
 | 131 | +        let f = lua.create_function(|lua, mut it: UserDataRefMut<LuaYamlMapIter>| {  | 
 | 132 | +            let root = it.borrow_value().root.clone();  | 
 | 133 | +            it.with_iter_mut(move |iter| match iter.next() {  | 
 | 134 | +                Some((key, value)) => {  | 
 | 135 | +                    // Convert YAML key to Lua value  | 
 | 136 | +                    let key = match key {  | 
 | 137 | +                        serde_yaml::Value::Null  | 
 | 138 | +                        | serde_yaml::Value::Bool(..)  | 
 | 139 | +                        | serde_yaml::Value::String(..)  | 
 | 140 | +                        | serde_yaml::Value::Number(..) => unsafe {  | 
 | 141 | +                            YamlObject::new(&root, key).into_lua(lua)?  | 
 | 142 | +                        },  | 
 | 143 | +                        _ => {  | 
 | 144 | +                            let err =  | 
 | 145 | +                                Error::runtime("only string/number/boolean keys are supported in YAML maps");  | 
 | 146 | +                            return Err(err);  | 
 | 147 | +                        }  | 
 | 148 | +                    };  | 
 | 149 | +                    let value = unsafe { YamlObject::new(&root, value) }.into_lua(lua)?;  | 
 | 150 | +                    (key, value).into_lua_multi(lua)  | 
 | 151 | +                }  | 
 | 152 | +                None => ().into_lua_multi(lua),  | 
 | 153 | +            })  | 
 | 154 | +        })?;  | 
 | 155 | +        lua.set_named_registry_value("__yaml_map_iterator", &f)?;  | 
 | 156 | +        Ok(f)  | 
 | 157 | +    }  | 
 | 158 | +}  | 
 | 159 | + | 
 | 160 | +impl From<serde_yaml::Value> for YamlObject {  | 
 | 161 | +    fn from(value: serde_yaml::Value) -> Self {  | 
 | 162 | +        let root = Arc::new(value);  | 
 | 163 | +        unsafe { Self::new(&root, &root) }  | 
 | 164 | +    }  | 
 | 165 | +}  | 
 | 166 | + | 
 | 167 | +impl UserData for YamlObject {  | 
 | 168 | +    fn register(registry: &mut mlua::UserDataRegistry<Self>) {  | 
 | 169 | +        registry.add_method("dump", |lua, this, ()| lua.to_value(this));  | 
 | 170 | + | 
 | 171 | +        registry.add_method("iter", |lua, this, ()| this.lua_iterator(lua));  | 
 | 172 | + | 
 | 173 | +        registry.add_meta_method(MetaMethod::Index, |lua, this, key: Value| {  | 
 | 174 | +            this.get(key)  | 
 | 175 | +                .map(|obj| obj.into_lua(lua))  | 
 | 176 | +                .unwrap_or(Ok(Value::Nil))  | 
 | 177 | +        });  | 
 | 178 | + | 
 | 179 | +        registry.add_meta_method(crate::METAMETHOD_ITER, |lua, this, ()| this.lua_iterator(lua));  | 
 | 180 | +    }  | 
 | 181 | +}  | 
 | 182 | + | 
 | 183 | +struct LuaYamlArrayIter {  | 
 | 184 | +    value: YamlObject,  | 
 | 185 | +    next: LuaInteger,  | 
 | 186 | +}  | 
 | 187 | + | 
 | 188 | +#[self_referencing]  | 
 | 189 | +struct LuaYamlMapIter {  | 
 | 190 | +    value: YamlObject,  | 
 | 191 | + | 
 | 192 | +    #[borrows(value)]  | 
 | 193 | +    #[covariant]  | 
 | 194 | +    iter: serde_yaml::mapping::Iter<'this>,  | 
 | 195 | +}  | 
 | 196 | + | 
 | 197 | +fn decode(lua: &Lua, (data, opts): (StringOrBytes, Option<Table>)) -> Result<StdResult<Value, String>> {  | 
 | 198 | +    let opts = opts.as_ref();  | 
 | 199 | +    let mut options = SerializeOptions::new();  | 
 | 200 | +    if let Some(enabled) = opts.and_then(|t| t.get::<bool>("set_array_metatable").ok()) {  | 
 | 201 | +        options = options.set_array_metatable(enabled);  | 
 | 202 | +    }  | 
 | 203 | + | 
 | 204 | +    let mut yaml: serde_yaml::Value = lua_try!(serde_yaml::from_slice(&data.as_bytes_deref()));  | 
 | 205 | +    lua_try!(yaml.apply_merge());  | 
 | 206 | +    Ok(Ok(lua.to_value_with(&yaml, options)?))  | 
 | 207 | +}  | 
 | 208 | + | 
 | 209 | +fn decode_native(lua: &Lua, data: StringOrBytes) -> Result<StdResult<Value, String>> {  | 
 | 210 | +    let mut yaml: serde_yaml::Value = lua_try!(serde_yaml::from_slice(&data.as_bytes_deref()));  | 
 | 211 | +    lua_try!(yaml.apply_merge());  | 
 | 212 | +    Ok(Ok(lua_try!(YamlObject::from(yaml).into_lua(lua))))  | 
 | 213 | +}  | 
 | 214 | + | 
 | 215 | +fn encode(value: Value, opts: Option<Table>) -> StdResult<String, String> {  | 
 | 216 | +    let opts = opts.as_ref();  | 
 | 217 | +    let mut value = value.to_serializable();  | 
 | 218 | + | 
 | 219 | +    if opts.and_then(|t| t.get::<bool>("relaxed").ok()) == Some(true) {  | 
 | 220 | +        value = value.deny_recursive_tables(false).deny_unsupported_types(false);  | 
 | 221 | +    }  | 
 | 222 | + | 
 | 223 | +    serde_yaml::to_string(&value).map_err(|e| e.to_string())  | 
 | 224 | +}  | 
 | 225 | + | 
 | 226 | +/// A loader for the `yaml` module.  | 
 | 227 | +fn loader(lua: &Lua) -> Result<Table> {  | 
 | 228 | +    let t = lua.create_table()?;  | 
 | 229 | +    t.set("decode", lua.create_function(decode)?)?;  | 
 | 230 | +    t.set("decode_native", lua.create_function(decode_native)?)?;  | 
 | 231 | +    t.set("encode", Function::wrap_raw(encode))?;  | 
 | 232 | +    Ok(t)  | 
 | 233 | +}  | 
 | 234 | + | 
 | 235 | +/// Registers the `yaml` module in the given Lua state.  | 
 | 236 | +pub fn register(lua: &Lua, name: Option<&str>) -> Result<Table> {  | 
 | 237 | +    let name = name.unwrap_or("@yaml");  | 
 | 238 | +    let value = loader(lua)?;  | 
 | 239 | +    lua.register_module(name, &value)?;  | 
 | 240 | +    Ok(value)  | 
 | 241 | +}  | 
0 commit comments