From c99c4c2a4da49db0b758d214a1d22110bc1c39d8 Mon Sep 17 00:00:00 2001 From: raven <7156279+RavenX8@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:07:40 -0500 Subject: [PATCH] - add: custom encoder and decoder for main struct and restricted types --- generator/src/codegen/rust/codegen_source.rs | 204 ++++++++++++++++++- 1 file changed, 195 insertions(+), 9 deletions(-) diff --git a/generator/src/codegen/rust/codegen_source.rs b/generator/src/codegen/rust/codegen_source.rs index 915f59b..70118ed 100644 --- a/generator/src/codegen/rust/codegen_source.rs +++ b/generator/src/codegen/rust/codegen_source.rs @@ -2,6 +2,7 @@ use ::flat_ast::*; use std::io::{Result, Write}; use ::heck::*; use std::collections::HashMap; +use flat_ast::RestrictionContent::{Enumeration, Length, MaxValue, MinValue}; pub (crate) struct CodeSourceGenerator<'a, W: Write + 'a> { writer: &'a mut ::writer::Writer, @@ -32,7 +33,9 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> { pub fn generate(&mut self, packet: &Packet) -> Result<()> { let version = self.version.clone(); cg!(self, "/* Generated with IDL v{} */\n", version); - cg!(self, r#"use bincode::{{Encode, Decode}};"#); + cg!(self, r#"use bincode::{{Encode, Decode, enc::Encoder, de::Decoder, error::DecodeError}};"#); + cg!(self, r#"use bincode::de::read::Reader;"#); + cg!(self, r#"use bincode::enc::write::Writer;"#); cg!(self, r#"use crate::packet::PacketPayload;"#); let mut iserialize: HashMap = HashMap::new(); @@ -70,7 +73,7 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> { cg!(self); - cg!(self, r#"#[derive(Debug, Encode, Decode)]"#); + cg!(self, r#"#[derive(Debug)]"#); cg!(self, "pub struct {} {{", packet.class_name().to_upper_camel_case()); self.indent(); for content in packet.contents() { @@ -85,6 +88,62 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> { cg!(self); cg!(self, "impl PacketPayload for {} {{}}", packet.class_name().to_upper_camel_case()); + cg!(self); + cg!(self, "impl Encode for {} {{", packet.class_name().to_upper_camel_case()); + self.indent(); + cg!(self, "fn encode(&self, encoder: &mut E) -> std::result::Result<(), bincode::error::EncodeError> {{"); + self.indent(); + for content in packet.contents() { + use self::PacketContent::*; + match content { + Element(ref elem) => { + let name = rename_if_reserved(elem.name()); + cg!(self, "self.{}.encode(encoder)?;", name); + }, + _ => {} + }; + } + cg!(self, "Ok(())"); + self.dedent(); + cg!(self, "}}"); + self.dedent(); + cg!(self, "}}"); + + cg!(self); + cg!(self, "impl Decode for {} {{", packet.class_name().to_upper_camel_case()); + self.indent(); + cg!(self, "fn decode(decoder: &mut D) -> std::result::Result {{"); + self.indent(); + for content in packet.contents() { + use self::PacketContent::*; + match content { + Element(ref elem) => { + let name = rename_if_reserved(elem.name()); + let trimmed_type = elem.type_().trim().to_string(); + let rust_type = iserialize.get(elem.type_().trim()).unwrap_or_else(|| { + debug!(r#"Type "{}" not found, outputting anyway"#, elem.type_()); + &trimmed_type + }); + cg!(self, "let {} = {}::decode(decoder)?;", name, rust_type); + }, + _ => {} + }; + } + cg!(self, "Ok({} {{", packet.class_name().to_upper_camel_case()); + for content in packet.contents() { + use self::PacketContent::*; + match content { + Element(ref elem) => { + cg!(self, "{},", rename_if_reserved(elem.name())); + }, + _ => {} + }; + } + cg!(self, "}})"); + self.dedent(); + cg!(self, "}}"); + self.dedent(); + cg!(self, "}}"); Ok(()) @@ -216,7 +275,7 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> { // }; let name = rename_if_reserved(elem.name()); // cg!(self, "{}: {}{}{},", elem.name(), type_, bits, default); - cg!(self, "{}: {}{},", name, type_, bits); + cg!(self, "pub(crate) {}: {}{},", name, type_, bits); Ok(()) } @@ -235,8 +294,8 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> { if is_enum { cg!(self, r#"#[repr({})]"#, rust_type); - cg!(self, r#"#[derive(Debug, Encode, Decode)]"#); - cg!(self, "enum {} {{", name.to_upper_camel_case()); + cg!(self, r#"#[derive(Debug, Clone)]"#); + cg!(self, "pub(crate) enum {} {{", name.to_upper_camel_case()); self.indent(); for content in restrict.contents() { if let Enumeration(en) = content { @@ -245,14 +304,142 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> { } } } else { - cg!(self, r#"#[derive(Debug, Encode, Decode)]"#); - cg!(self, "struct {} {{", name.to_upper_camel_case()); + cg!(self, r#"#[derive(Debug)]"#); + cg!(self, "pub struct {} {{", name.to_upper_camel_case()); self.indent(); - cg!(self, "{}: {},", name.to_string().to_snake_case(), rust_type); + cg!(self, "pub(crate) {}: {},", name.to_string().to_snake_case(), rust_type); } self.dedent(); cg!(self, "}}"); + + cg!(self); + self.restrict_encode(&restrict, name, iserialize)?; + cg!(self); + self.restrict_decode(&restrict, name, iserialize)?; + Ok(()) + } + + fn restrict_encode(&mut self, restrict: &Restriction, name: &str, iserialize: &HashMap) -> Result<()> { + let is_enum = restrict.contents().iter().find(|content| match content { + Enumeration(_) => true, + _ => false + }).is_some(); + let trimmed_type = restrict.base().trim().to_string(); + let rust_type = iserialize.get(restrict.base().trim()).unwrap_or_else(|| { + debug!(r#"Type "{}" not found, outputting anyway"#, restrict.base()); + &trimmed_type + }); + + cg!(self, "impl Encode for {} {{", name.to_upper_camel_case()); + self.indent(); + cg!(self, "fn encode(&self, encoder: &mut E) -> std::result::Result<(), bincode::error::EncodeError> {{"); + self.indent(); + if is_enum { + cg!(self, "encoder.writer().write(&[self.clone() as u8]).map_err(Into::into)"); + } else { + let data = name.to_string().to_snake_case(); + cg!(self, "let bytes = self.{}.as_bytes();", data); + for content in restrict.contents() { + match content { + Length(l) => { + cg!(self, "let fixed_length = {};", l); + cg!(self, "if bytes.len() > fixed_length {{"); + self.indent(); + cg!(self, "return Err(bincode::error::EncodeError::OtherString(format!("); + cg!(self, "\"{} length exceeds fixed size: {{}} > {{}}\", bytes.len(), fixed_length)));", data); + self.dedent(); + cg!(self, "}}"); + cg!(self, "encoder.writer().write(bytes)?;"); + cg!(self, "encoder.writer().write(&vec![0; fixed_length - bytes.len()])?;"); + cg!(self, "Ok(())"); + }, + MinValue(v) => { + + }, + MaxValue(v) => { + + }, + _ => panic!("enumeration in restrict when there shouldn't be one") + } + } + } + self.dedent(); + cg!(self, "}}"); + self.dedent(); + cg!(self, "}}"); + + Ok(()) + } + + fn restrict_decode(&mut self, restrict: &Restriction, name: &str, iserialize: &HashMap) -> Result<()> { + let is_enum = restrict.contents().iter().find(|content| match content { + Enumeration(_) => true, + _ => false + }).is_some(); + let trimmed_type = restrict.base().trim().to_string(); + let rust_type = iserialize.get(restrict.base().trim()).unwrap_or_else(|| { + debug!(r#"Type "{}" not found, outputting anyway"#, restrict.base()); + &trimmed_type + }); + + cg!(self, "impl Decode for {} {{", name.to_upper_camel_case()); + self.indent(); + cg!(self, "fn decode(decoder: &mut D) -> std::result::Result {{"); + self.indent(); + if is_enum { + cg!(self, "let value = {}::decode(decoder)?;", rust_type); + cg!(self, "match value {{"); + self.indent(); + for content in restrict.contents() { + if let Enumeration(en) = content { + cg!(self, "{} => Ok({}::{}),", en.id(), name.to_upper_camel_case(), en.value().to_upper_camel_case()); + } + } + cg!(self, "_ => Err(bincode::error::DecodeError::OtherString(format!(\"Invalid value for {}: {{}}\", value))),", name.to_upper_camel_case()); + self.dedent(); + cg!(self, "}}"); + } else { + let data = name.to_string().to_snake_case(); + let mut fixed_length = 64; + let mut minValueCheck = String::new(); + let mut maxValueCheck = String::new(); + for content in restrict.contents() { + match content { + Length(l) => { + fixed_length = *l; + }, + MinValue(v) => { + minValueCheck = format!("if {} < {} {{Err(bincode::error::DecodeError::OtherString(format!(\"Invalid value for {}: {{}} < {{}}\", {}, {})))}}", data, v, data, data, v).into(); + }, + MaxValue(v) => { + maxValueCheck = format!("if {} > {} {{Err(bincode::error::DecodeError::OtherString(format!(\"Invalid value for {}: {{}} > {{}}\", {}, {})))}}", data, v, data, data, v).into(); + }, + _ => panic!("enumeration in restrict when there shouldn't be one") + } + } + + if rust_type == "String" { + cg!(self, "let mut buffer = vec![0u8; {}];", fixed_length); + cg!(self, "decoder.reader().read(&mut buffer)?;"); + cg!(self, "let {} = {}::from_utf8(buffer)", data, rust_type); + cg!(self, ".map_err(|e| DecodeError::OtherString(format!(\"Invalid UTF-8: {{}}\", e)))?"); + cg!(self, ".trim_end_matches('\\0')"); + cg!(self, ".to_string();"); + } else { + cg!(self, "let {} = {}::decode(buffer)?;", data, rust_type); + + cg!(self, "{}", minValueCheck); + + cg!(self, "{}", maxValueCheck); + } + cg!(self, "Ok({} {{ {} }})", name.to_upper_camel_case(), data); + } + self.dedent(); + cg!(self, "}}"); + self.dedent(); + cg!(self, "}}"); + Ok(()) } } @@ -273,4 +460,3 @@ fn rename_if_reserved(name: &str) -> String { name.to_string().to_snake_case() } } -