Skip to content

Commit

Permalink
- add: custom encoder and decoder for main struct and restricted types
Browse files Browse the repository at this point in the history
  • Loading branch information
RavenX8 committed Dec 9, 2024
1 parent 5e252df commit c99c4c2
Showing 1 changed file with 195 additions and 9 deletions.
204 changes: 195 additions & 9 deletions generator/src/codegen/rust/codegen_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use ::flat_ast::*;
use std::io::{Result, Write};
use ::heck::*;
use std::collections::HashMap;
use flat_ast::RestrictionContent::{Enumeration, Length, MaxValue, MinValue};

pub (crate) struct CodeSourceGenerator<'a, W: Write + 'a> {
writer: &'a mut ::writer::Writer<W>,
Expand Down Expand Up @@ -32,7 +33,9 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> {
pub fn generate(&mut self, packet: &Packet) -> Result<()> {
let version = self.version.clone();
cg!(self, "/* Generated with IDL v{} */\n", version);
cg!(self, r#"use bincode::{{Encode, Decode}};"#);
cg!(self, r#"use bincode::{{Encode, Decode, enc::Encoder, de::Decoder, error::DecodeError}};"#);
cg!(self, r#"use bincode::de::read::Reader;"#);
cg!(self, r#"use bincode::enc::write::Writer;"#);
cg!(self, r#"use crate::packet::PacketPayload;"#);

let mut iserialize: HashMap<String, String> = HashMap::new();
Expand Down Expand Up @@ -70,7 +73,7 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> {

cg!(self);

cg!(self, r#"#[derive(Debug, Encode, Decode)]"#);
cg!(self, r#"#[derive(Debug)]"#);
cg!(self, "pub struct {} {{", packet.class_name().to_upper_camel_case());
self.indent();
for content in packet.contents() {
Expand All @@ -85,6 +88,62 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> {

cg!(self);
cg!(self, "impl PacketPayload for {} {{}}", packet.class_name().to_upper_camel_case());
cg!(self);
cg!(self, "impl Encode for {} {{", packet.class_name().to_upper_camel_case());
self.indent();
cg!(self, "fn encode<E: Encoder>(&self, encoder: &mut E) -> std::result::Result<(), bincode::error::EncodeError> {{");
self.indent();
for content in packet.contents() {
use self::PacketContent::*;
match content {
Element(ref elem) => {
let name = rename_if_reserved(elem.name());
cg!(self, "self.{}.encode(encoder)?;", name);
},
_ => {}
};
}
cg!(self, "Ok(())");
self.dedent();
cg!(self, "}}");
self.dedent();
cg!(self, "}}");

cg!(self);
cg!(self, "impl Decode for {} {{", packet.class_name().to_upper_camel_case());
self.indent();
cg!(self, "fn decode<D: Decoder>(decoder: &mut D) -> std::result::Result<Self, bincode::error::DecodeError> {{");
self.indent();
for content in packet.contents() {
use self::PacketContent::*;
match content {
Element(ref elem) => {
let name = rename_if_reserved(elem.name());
let trimmed_type = elem.type_().trim().to_string();
let rust_type = iserialize.get(elem.type_().trim()).unwrap_or_else(|| {
debug!(r#"Type "{}" not found, outputting anyway"#, elem.type_());
&trimmed_type
});
cg!(self, "let {} = {}::decode(decoder)?;", name, rust_type);
},
_ => {}
};
}
cg!(self, "Ok({} {{", packet.class_name().to_upper_camel_case());
for content in packet.contents() {
use self::PacketContent::*;
match content {
Element(ref elem) => {
cg!(self, "{},", rename_if_reserved(elem.name()));
},
_ => {}
};
}
cg!(self, "}})");
self.dedent();
cg!(self, "}}");
self.dedent();
cg!(self, "}}");


Ok(())
Expand Down Expand Up @@ -216,7 +275,7 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> {
// };
let name = rename_if_reserved(elem.name());
// cg!(self, "{}: {}{}{},", elem.name(), type_, bits, default);
cg!(self, "{}: {}{},", name, type_, bits);
cg!(self, "pub(crate) {}: {}{},", name, type_, bits);
Ok(())
}

Expand All @@ -235,8 +294,8 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> {

if is_enum {
cg!(self, r#"#[repr({})]"#, rust_type);
cg!(self, r#"#[derive(Debug, Encode, Decode)]"#);
cg!(self, "enum {} {{", name.to_upper_camel_case());
cg!(self, r#"#[derive(Debug, Clone)]"#);
cg!(self, "pub(crate) enum {} {{", name.to_upper_camel_case());
self.indent();
for content in restrict.contents() {
if let Enumeration(en) = content {
Expand All @@ -245,14 +304,142 @@ impl<'a, W: Write> CodeSourceGenerator<'a, W> {
}
}
} else {
cg!(self, r#"#[derive(Debug, Encode, Decode)]"#);
cg!(self, "struct {} {{", name.to_upper_camel_case());
cg!(self, r#"#[derive(Debug)]"#);
cg!(self, "pub struct {} {{", name.to_upper_camel_case());
self.indent();
cg!(self, "{}: {},", name.to_string().to_snake_case(), rust_type);
cg!(self, "pub(crate) {}: {},", name.to_string().to_snake_case(), rust_type);
}

self.dedent();
cg!(self, "}}");

cg!(self);
self.restrict_encode(&restrict, name, iserialize)?;
cg!(self);
self.restrict_decode(&restrict, name, iserialize)?;
Ok(())
}

fn restrict_encode(&mut self, restrict: &Restriction, name: &str, iserialize: &HashMap<String, String>) -> Result<()> {
let is_enum = restrict.contents().iter().find(|content| match content {
Enumeration(_) => true,
_ => false
}).is_some();
let trimmed_type = restrict.base().trim().to_string();
let rust_type = iserialize.get(restrict.base().trim()).unwrap_or_else(|| {
debug!(r#"Type "{}" not found, outputting anyway"#, restrict.base());
&trimmed_type
});

cg!(self, "impl Encode for {} {{", name.to_upper_camel_case());
self.indent();
cg!(self, "fn encode<E: Encoder>(&self, encoder: &mut E) -> std::result::Result<(), bincode::error::EncodeError> {{");
self.indent();
if is_enum {
cg!(self, "encoder.writer().write(&[self.clone() as u8]).map_err(Into::into)");
} else {
let data = name.to_string().to_snake_case();
cg!(self, "let bytes = self.{}.as_bytes();", data);
for content in restrict.contents() {
match content {
Length(l) => {
cg!(self, "let fixed_length = {};", l);
cg!(self, "if bytes.len() > fixed_length {{");
self.indent();
cg!(self, "return Err(bincode::error::EncodeError::OtherString(format!(");
cg!(self, "\"{} length exceeds fixed size: {{}} > {{}}\", bytes.len(), fixed_length)));", data);
self.dedent();
cg!(self, "}}");
cg!(self, "encoder.writer().write(bytes)?;");
cg!(self, "encoder.writer().write(&vec![0; fixed_length - bytes.len()])?;");
cg!(self, "Ok(())");
},
MinValue(v) => {

},
MaxValue(v) => {

},
_ => panic!("enumeration in restrict when there shouldn't be one")
}
}
}
self.dedent();
cg!(self, "}}");
self.dedent();
cg!(self, "}}");

Ok(())
}

fn restrict_decode(&mut self, restrict: &Restriction, name: &str, iserialize: &HashMap<String, String>) -> Result<()> {
let is_enum = restrict.contents().iter().find(|content| match content {
Enumeration(_) => true,
_ => false
}).is_some();
let trimmed_type = restrict.base().trim().to_string();
let rust_type = iserialize.get(restrict.base().trim()).unwrap_or_else(|| {
debug!(r#"Type "{}" not found, outputting anyway"#, restrict.base());
&trimmed_type
});

cg!(self, "impl Decode for {} {{", name.to_upper_camel_case());
self.indent();
cg!(self, "fn decode<D: Decoder>(decoder: &mut D) -> std::result::Result<Self, bincode::error::DecodeError> {{");
self.indent();
if is_enum {
cg!(self, "let value = {}::decode(decoder)?;", rust_type);
cg!(self, "match value {{");
self.indent();
for content in restrict.contents() {
if let Enumeration(en) = content {
cg!(self, "{} => Ok({}::{}),", en.id(), name.to_upper_camel_case(), en.value().to_upper_camel_case());
}
}
cg!(self, "_ => Err(bincode::error::DecodeError::OtherString(format!(\"Invalid value for {}: {{}}\", value))),", name.to_upper_camel_case());
self.dedent();
cg!(self, "}}");
} else {
let data = name.to_string().to_snake_case();
let mut fixed_length = 64;
let mut minValueCheck = String::new();
let mut maxValueCheck = String::new();
for content in restrict.contents() {
match content {
Length(l) => {
fixed_length = *l;
},
MinValue(v) => {
minValueCheck = format!("if {} < {} {{Err(bincode::error::DecodeError::OtherString(format!(\"Invalid value for {}: {{}} < {{}}\", {}, {})))}}", data, v, data, data, v).into();
},
MaxValue(v) => {
maxValueCheck = format!("if {} > {} {{Err(bincode::error::DecodeError::OtherString(format!(\"Invalid value for {}: {{}} > {{}}\", {}, {})))}}", data, v, data, data, v).into();
},
_ => panic!("enumeration in restrict when there shouldn't be one")
}
}

if rust_type == "String" {
cg!(self, "let mut buffer = vec![0u8; {}];", fixed_length);
cg!(self, "decoder.reader().read(&mut buffer)?;");
cg!(self, "let {} = {}::from_utf8(buffer)", data, rust_type);
cg!(self, ".map_err(|e| DecodeError::OtherString(format!(\"Invalid UTF-8: {{}}\", e)))?");
cg!(self, ".trim_end_matches('\\0')");
cg!(self, ".to_string();");
} else {
cg!(self, "let {} = {}::decode(buffer)?;", data, rust_type);

cg!(self, "{}", minValueCheck);

cg!(self, "{}", maxValueCheck);
}
cg!(self, "Ok({} {{ {} }})", name.to_upper_camel_case(), data);
}
self.dedent();
cg!(self, "}}");
self.dedent();
cg!(self, "}}");

Ok(())
}
}
Expand All @@ -273,4 +460,3 @@ fn rename_if_reserved(name: &str) -> String {
name.to_string().to_snake_case()
}
}

0 comments on commit c99c4c2

Please sign in to comment.