Skip to content

Commit

Permalink
Add a bound to only allow encoding packets to the tokio codec
Browse files Browse the repository at this point in the history
  • Loading branch information
coolreader18 authored and zonyitoo committed Jan 19, 2021
1 parent 33c61e5 commit ef69da8
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ macro_rules! encodable_packet {
}
}

impl $crate::packet::EncodablePacket for $typ {}

impl $typ {
fn encoded_length_noheader(&self) -> u32 {
$($crate::encodable::Encodable::encoded_length(&self.$field) +)*
Expand Down Expand Up @@ -73,6 +75,14 @@ pub mod subscribe;
pub mod unsuback;
pub mod unsubscribe;

/// A trait representing a packet that can be encoded, when passed as `FooPacket` or as
/// `&FooPacket`. Different from [`Encodable`] in that it prevents you from accidentally passing
/// a type intended to be encoded only as a part of a packet and doesn't have a header, e.g.
/// `Vec<u8>`.
pub trait EncodablePacket: Encodable {}

impl<T: EncodablePacket> EncodablePacket for &T {}

/// Methods for encoding and decoding a packet
pub trait Packet: Encodable + fmt::Debug + Sized + 'static {
type Payload: Encodable + Decodable;
Expand All @@ -82,7 +92,7 @@ pub trait Packet: Encodable + fmt::Debug + Sized + 'static {
/// Get a borrow of payload
fn payload_ref(&self) -> &Self::Payload;

/// Deocde packet given a `FixedHeader`
/// Decode packet given a `FixedHeader`
fn decode_packet<R: Read>(reader: &mut R, fixed_header: FixedHeader) -> Result<Self, PacketError<Self>>;
}

Expand Down Expand Up @@ -185,6 +195,8 @@ macro_rules! impl_variable_packet {
}
}

impl EncodablePacket for VariablePacket {}

impl Decodable for VariablePacket {
type Error = VariablePacketError;
type Cond = Option<FixedHeader>;
Expand Down Expand Up @@ -400,7 +412,7 @@ mod tokio_codec {
}
}

impl<T: Encodable> codec::Encoder<T> for MqttEncoder {
impl<T: EncodablePacket> codec::Encoder<T> for MqttEncoder {
type Error = io::Error;
fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> {
dst.reserve(packet.encoded_length() as usize);
Expand Down Expand Up @@ -431,7 +443,7 @@ mod tokio_codec {
}
}

impl<T: Encodable> codec::Encoder<T> for MqttCodec {
impl<T: EncodablePacket> codec::Encoder<T> for MqttCodec {
type Error = io::Error;
#[inline]
fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> {
Expand Down

0 comments on commit ef69da8

Please sign in to comment.