|
| 1 | +use alloc::string::{String, ToString}; |
| 2 | +use core::error::Error as CoreError; |
| 3 | +use core::ffi::CStr; |
| 4 | +use core::fmt::Display; |
| 5 | + |
| 6 | +use serde::de::{ |
| 7 | + self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, StdError, |
| 8 | + VariantAccess, Visitor, |
| 9 | +}; |
| 10 | +use serde::{forward_to_deserialize_any, Deserialize}; |
| 11 | + |
/// Cursor over a borrowed BSON byte buffer.
///
/// The parser only ever moves forward: every successful read removes bytes from the front of
/// `remaining_input` and increases `offset` by the same amount.
#[derive(Debug)]
pub struct Parser<'de> {
    /// Number of bytes consumed so far; attached to errors so failures can be located.
    offset: usize,
    /// Input that has not been consumed yet.
    remaining_input: &'de [u8],
}
| 16 | + |
| 17 | +impl<'de> Parser<'de> { |
| 18 | + fn error(&self, kind: ErrorKind) -> BsonError { |
| 19 | + BsonError { |
| 20 | + offset: Some(self.offset), |
| 21 | + kind: kind, |
| 22 | + } |
| 23 | + } |
| 24 | + |
| 25 | + fn advance(&mut self, by: usize) { |
| 26 | + self.offset = self.offset.strict_add(by); |
| 27 | + self.remaining_input = &self.remaining_input[by..]; |
| 28 | + } |
| 29 | + |
| 30 | + fn advance_checked(&mut self, size: usize) -> Result<&'de [u8], BsonError> { |
| 31 | + let (taken, rest) = self |
| 32 | + .remaining_input |
| 33 | + .split_at_checked(size) |
| 34 | + .ok_or_else(|| self.error(ErrorKind::UnexpectedEoF))?; |
| 35 | + |
| 36 | + self.offset += size; |
| 37 | + self.remaining_input = rest; |
| 38 | + Ok(taken) |
| 39 | + } |
| 40 | + |
| 41 | + fn advance_byte(&mut self) -> Result<u8, BsonError> { |
| 42 | + let slice = self.advance_checked(1)?; |
| 43 | + Ok(slice[0]) |
| 44 | + } |
| 45 | + |
| 46 | + fn read_cstr(&mut self) -> Result<&'de str, BsonError> { |
| 47 | + let raw = CStr::from_bytes_until_nul(self.remaining_input) |
| 48 | + .map_err(|_| self.error(ErrorKind::UnterminatedCString))?; |
| 49 | + let str = raw |
| 50 | + .to_str() |
| 51 | + .map_err(|_| self.error(ErrorKind::InvalidCString))?; |
| 52 | + |
| 53 | + self.advance(str.len() + 1); |
| 54 | + Ok(str) |
| 55 | + } |
| 56 | + |
| 57 | + fn read_int32(&mut self) -> Result<i32, BsonError> { |
| 58 | + let slice = self.advance_checked(4)?; |
| 59 | + Ok(i32::from_le_bytes( |
| 60 | + slice.try_into().expect("should have correct length"), |
| 61 | + )) |
| 62 | + } |
| 63 | + |
| 64 | + fn read_length(&mut self) -> Result<usize, BsonError> { |
| 65 | + let raw = self.read_int32()?; |
| 66 | + u32::try_from(raw) |
| 67 | + .and_then(usize::try_from) |
| 68 | + .map_err(|_| self.error(ErrorKind::InvalidSize)) |
| 69 | + } |
| 70 | + |
| 71 | + fn read_int64(&mut self) -> Result<i64, BsonError> { |
| 72 | + let slice = self.advance_checked(8)?; |
| 73 | + Ok(i64::from_le_bytes( |
| 74 | + slice.try_into().expect("should have correct length"), |
| 75 | + )) |
| 76 | + } |
| 77 | + |
| 78 | + fn read_double(&mut self) -> Result<f64, BsonError> { |
| 79 | + let slice = self.advance_checked(8)?; |
| 80 | + Ok(f64::from_le_bytes( |
| 81 | + slice.try_into().expect("should have correct length"), |
| 82 | + )) |
| 83 | + } |
| 84 | + |
| 85 | + /// Reads a BSON string, `string ::= int32 (byte*) unsigned_byte(0)` |
| 86 | + fn read_string(&mut self) -> Result<&'de str, BsonError> { |
| 87 | + let length_including_null = self.read_length()?; |
| 88 | + let bytes = self.advance_checked(length_including_null)?; |
| 89 | + |
| 90 | + str::from_utf8(&bytes[..length_including_null - 1]) |
| 91 | + .map_err(|_| self.error(ErrorKind::InvalidCString)) |
| 92 | + } |
| 93 | + |
| 94 | + fn read_binary(&mut self) -> Result<(BinarySubtype, &'de [u8]), BsonError> { |
| 95 | + let length = self.read_length()?; |
| 96 | + let subtype = self.advance_byte()?; |
| 97 | + let binary = self.advance_checked(length)?; |
| 98 | + |
| 99 | + Ok((BinarySubtype(subtype), binary)) |
| 100 | + } |
| 101 | + |
| 102 | + fn read_element_type(&mut self) -> Result<ElementType, BsonError> { |
| 103 | + let raw_type = self.advance_byte()? as i8; |
| 104 | + Ok(match raw_type { |
| 105 | + 1 => ElementType::Double, |
| 106 | + 2 => ElementType::String, |
| 107 | + 3 => ElementType::Document, |
| 108 | + 4 => ElementType::Array, |
| 109 | + 5 => ElementType::Binary, |
| 110 | + 6 => ElementType::Undefined, |
| 111 | + 7 => ElementType::ObjectId, |
| 112 | + 8 => ElementType::Boolean, |
| 113 | + 9 => ElementType::DatetimeUtc, |
| 114 | + 10 => ElementType::Null, |
| 115 | + 16 => ElementType::Int32, |
| 116 | + 17 => ElementType::Timestamp, |
| 117 | + 18 => ElementType::Int64, |
| 118 | + _ => return Err(self.error(ErrorKind::UnknownElementType(raw_type))), |
| 119 | + }) |
| 120 | + } |
| 121 | + |
| 122 | + fn subreader(&mut self, len: usize) -> Result<Parser<'de>, BsonError> { |
| 123 | + let current_offset = self.offset; |
| 124 | + let for_sub_reader = self.advance_checked(len)?; |
| 125 | + Ok(Parser { |
| 126 | + offset: current_offset, |
| 127 | + remaining_input: for_sub_reader, |
| 128 | + }) |
| 129 | + } |
| 130 | + |
| 131 | + fn document_scope(&mut self) -> Result<Parser<'de>, BsonError> { |
| 132 | + let total_size = self.read_length()?; |
| 133 | + if total_size < 5 { |
| 134 | + return Err(self.error(ErrorKind::InvalidSize))?; |
| 135 | + } |
| 136 | + |
| 137 | + self.subreader(total_size - 4) |
| 138 | + } |
| 139 | +} |
| 140 | + |
/// Subtype byte that accompanies a BSON binary value, stored exactly as read from the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
struct BinarySubtype(pub u8);
| 143 | + |
/// Type tag of a BSON element.
///
/// Each discriminant is the tag byte used for that type on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ElementType {
    Double = 1,
    String = 2,
    Document = 3,
    Array = 4,
    Binary = 5,
    Undefined = 6,
    ObjectId = 7,
    Boolean = 8,
    DatetimeUtc = 9,
    Null = 10,
    Int32 = 16,
    Timestamp = 17,
    Int64 = 18,
}
| 159 | + |
| 160 | +struct Deserializer<'de> { |
| 161 | + parser: Parser<'de>, |
| 162 | + is_outside_of_document: bool, |
| 163 | + pending_value_type: Option<ElementType>, |
| 164 | + consumed_name: bool, |
| 165 | +} |
| 166 | + |
| 167 | +impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { |
| 168 | + type Error = BsonError; |
| 169 | + |
| 170 | + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> |
| 171 | + where |
| 172 | + V: Visitor<'de>, |
| 173 | + { |
| 174 | + // BSON always start with a document, so we need this for the outermost visit_map. |
| 175 | + if self.is_outside_of_document { |
| 176 | + self.parser = self.parser.document_scope()?; |
| 177 | + self.is_outside_of_document = false; |
| 178 | + |
| 179 | + let object = BsonObject { de: self }; |
| 180 | + return visitor.visit_map(object); |
| 181 | + } |
| 182 | + |
| 183 | + if !self.consumed_name { |
| 184 | + self.consumed_name = true; |
| 185 | + // We've read an element type, but not the associated name. Do that now. |
| 186 | + return visitor.visit_borrowed_str(self.parser.read_cstr()?); |
| 187 | + } |
| 188 | + |
| 189 | + if let Some(element_type) = self.pending_value_type.take() { |
| 190 | + return match element_type { |
| 191 | + ElementType::Double => visitor.visit_f64(self.parser.read_double()?), |
| 192 | + ElementType::String => visitor.visit_borrowed_str(self.parser.read_string()?), |
| 193 | + ElementType::Document => { |
| 194 | + let parser = self.parser.document_scope()?; |
| 195 | + let mut deserializer = Deserializer { |
| 196 | + parser, |
| 197 | + is_outside_of_document: false, |
| 198 | + pending_value_type: None, |
| 199 | + consumed_name: false, |
| 200 | + }; |
| 201 | + let object = BsonObject { |
| 202 | + de: &mut deserializer, |
| 203 | + }; |
| 204 | + |
| 205 | + visitor.visit_map(object) |
| 206 | + } |
| 207 | + ElementType::Array => todo!(), |
| 208 | + ElementType::Binary => { |
| 209 | + let (_, bytes) = self.parser.read_binary()?; |
| 210 | + visitor.visit_borrowed_bytes(bytes) |
| 211 | + } |
| 212 | + ElementType::ObjectId => todo!(), |
| 213 | + ElementType::Boolean => { |
| 214 | + let value = self.parser.advance_byte()?; |
| 215 | + visitor.visit_bool(value != 0) |
| 216 | + } |
| 217 | + ElementType::DatetimeUtc => todo!(), |
| 218 | + ElementType::Null | ElementType::Undefined => visitor.visit_none(), |
| 219 | + ElementType::Int32 => visitor.visit_i32(self.parser.read_int32()?), |
| 220 | + ElementType::Int64 => visitor.visit_i64(self.parser.read_int64()?), |
| 221 | + ElementType::Timestamp => todo!(), |
| 222 | + }; |
| 223 | + } |
| 224 | + |
| 225 | + todo!() |
| 226 | + } |
| 227 | + |
| 228 | + forward_to_deserialize_any! { |
| 229 | + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string |
| 230 | + bytes byte_buf option unit unit_struct newtype_struct seq tuple |
| 231 | + tuple_struct map struct enum identifier ignored_any |
| 232 | + } |
| 233 | +} |
| 234 | +struct BsonObject<'a, 'de: 'a> { |
| 235 | + de: &'a mut Deserializer<'de>, |
| 236 | +} |
| 237 | + |
| 238 | +impl<'de, 'a> MapAccess<'de> for BsonObject<'a, 'de> { |
| 239 | + type Error = BsonError; |
| 240 | + |
| 241 | + fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error> |
| 242 | + where |
| 243 | + K: DeserializeSeed<'de>, |
| 244 | + { |
| 245 | + if self.de.parser.remaining_input.len() == 1 { |
| 246 | + // Expect trailing 0 for document |
| 247 | + let trailing_zero = self.de.parser.advance_byte()?; |
| 248 | + if trailing_zero != 0 { |
| 249 | + return Err(self.de.parser.error(ErrorKind::InvalidEndOfDocument)); |
| 250 | + } |
| 251 | + |
| 252 | + return Ok(None); |
| 253 | + } |
| 254 | + |
| 255 | + self.de.pending_value_type = Some(self.de.parser.read_element_type()?); |
| 256 | + self.de.consumed_name = false; |
| 257 | + |
| 258 | + Ok(Some(seed.deserialize(&mut *self.de)?)) |
| 259 | + } |
| 260 | + |
| 261 | + fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error> |
| 262 | + where |
| 263 | + V: DeserializeSeed<'de>, |
| 264 | + { |
| 265 | + debug_assert!(self.de.consumed_name); |
| 266 | + debug_assert!(self.de.pending_value_type.is_some()); |
| 267 | + |
| 268 | + seed.deserialize(&mut *self.de) |
| 269 | + } |
| 270 | +} |
| 271 | + |
/// Error produced while deserializing BSON.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BsonError {
    /// Byte offset the parser had reached when the error was raised, if known.
    offset: Option<usize>,
    /// What went wrong.
    kind: ErrorKind,
}

/// The different ways deserialization can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ErrorKind {
    /// Free-form message, used for errors raised through serde.
    Custom(String),
    /// An element started with a type tag that is not recognized; carries the raw tag.
    UnknownElementType(i8),
    /// A string ran to the end of the input without a NUL terminator.
    UnterminatedCString,
    /// A string's bytes were not valid UTF-8.
    InvalidCString,
    /// The input ended before a value was complete.
    UnexpectedEoF,
    /// The last byte of a document was not zero.
    InvalidEndOfDocument,
    /// A length prefix was negative or too small for the value it describes.
    InvalidSize,
}

impl Display for BsonError {
    /// Writes `bson error: <description>`, followed by ` (at offset N)` when an offset is known.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("bson error: ")?;

        match &self.kind {
            ErrorKind::Custom(message) => f.write_str(message)?,
            ErrorKind::UnknownElementType(raw) => write!(f, "unknown element type {}", raw)?,
            ErrorKind::UnterminatedCString => {
                f.write_str("string is missing its NUL terminator")?
            }
            ErrorKind::InvalidCString => f.write_str("string is not valid UTF-8")?,
            ErrorKind::UnexpectedEoF => f.write_str("unexpected end of input")?,
            ErrorKind::InvalidEndOfDocument => {
                f.write_str("document is not terminated by a zero byte")?
            }
            ErrorKind::InvalidSize => f.write_str("invalid size")?,
        }

        match self.offset {
            Some(offset) => write!(f, " (at offset {})", offset),
            None => Ok(()),
        }
    }
}
| 294 | + |
| 295 | +impl de::Error for BsonError { |
| 296 | + fn custom<T>(msg: T) -> Self |
| 297 | + where |
| 298 | + T: Display, |
| 299 | + { |
| 300 | + BsonError { |
| 301 | + offset: None, |
| 302 | + kind: ErrorKind::Custom(msg.to_string()), |
| 303 | + } |
| 304 | + } |
| 305 | +} |
| 306 | + |
| 307 | +impl StdError for BsonError {} |
| 308 | + |
| 309 | +pub fn from_bytes<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result<T, BsonError> { |
| 310 | + let parser = Parser { |
| 311 | + offset: 0, |
| 312 | + remaining_input: bytes, |
| 313 | + }; |
| 314 | + let mut deserializer = Deserializer { |
| 315 | + parser, |
| 316 | + is_outside_of_document: true, |
| 317 | + pending_value_type: None, |
| 318 | + consumed_name: false, |
| 319 | + }; |
| 320 | + |
| 321 | + T::deserialize(&mut deserializer) |
| 322 | +} |
| 323 | + |
// These tests build their input with the `bson` crate and need `std`, so they only compile
// when the `std` feature is enabled.
#[cfg(all(feature = "std", test))]
mod test {
    extern crate std;

    use std::vec::Vec;
    use std::*;

    use bson::{Bson, Document};
    use serde::de::DeserializeOwned;

    use super::*;

    /// A single string field can be read back as a borrowed `&str`.
    #[test]
    fn test_hello_world() {
        #[derive(Deserialize)]
        struct Expected<'a> {
            hello: &'a str,
        }

        let mut doc = Document::new();
        doc.insert("hello", "world");

        let mut bytes = Vec::<u8>::new();
        doc.to_writer(&mut bytes).expect("should serialize");

        let Expected { hello } = from_bytes(&bytes).expect("should deserialize");
        assert_eq!(hello, "world");
    }
}
0 commit comments