diff --git a/bitbybit-tests/src/bitfield_tests.rs b/bitbybit-tests/src/bitfield_tests.rs index bcb468c..ea2e27f 100644 --- a/bitbybit-tests/src/bitfield_tests.rs +++ b/bitbybit-tests/src/bitfield_tests.rs @@ -431,10 +431,13 @@ fn builder_available_in_const_context() { a: u16, } - assert_eq!(const { Test::builder().with_a(123).build().raw_value() }, 123); + assert_eq!( + const { Test::builder().with_a(123).build().raw_value() }, + 123 + ); const { let raw = Test::builder().with_a(123).build().raw_value(); - if raw != 123 { + if raw != 123 { panic!("builder didn't build the right value `123`"); } } @@ -2131,6 +2134,118 @@ fn overlapping_fields_fully_covering_range() { let _ = Test::builder().with_b(0).build(); } +#[test] +fn const_generic() { + // test support for arbitrary generics. These aren't used by bitfield itself but are allowed + // so that the user can add special methods + #[bitfield(u16)] + pub struct Test { + #[bits(0..=15, rw)] + b: u16, + } + + impl Test { + fn be() -> bool { + BE + } + } + + assert!(Test::::be()); + assert!(!Test::::be()); + + // builder still works through the generic + let t: Test = Test::::builder().with_b(0x1234).build(); + assert_eq!(t.raw_value(), 0x1234); +} + +#[test] +fn type_generic() { + // type parameters are allowed even when the bitfield itself doesn't use them + trait Tag { + const NAME: &'static str; + } + struct A; + struct B; + impl Tag for A { + const NAME: &'static str = "a"; + } + impl Tag for B { + const NAME: &'static str = "b"; + } + + #[bitfield(u16, default = 0)] + pub struct Test { + #[bits(0..=15, rw)] + b: u16, + } + + impl Test { + fn tag_name() -> &'static str { + T::NAME + } + } + + assert_eq!(Test::::tag_name(), "a"); + assert_eq!(Test::::tag_name(), "b"); + let t: Test = Test::::builder().with_b(7).build(); + assert_eq!(t.raw_value(), 7); + assert_eq!(Test::::ZERO.raw_value(), 0); +} + +#[test] +fn lifetime_generic() { + // lifetime parameters are also supported + #[bitfield(u16)] + pub struct Test<'a> { + #[bits(0..=15, rw)] + b: u16, + } + + impl<'a> Test<'a> { + fn from_slice(_s: &'a [u8]) -> Self { + Self::new_with_raw_value(0xABCD) + } + } + + let data = [0u8; 4]; + let t = Test::from_slice(&data); + assert_eq!(t.raw_value(), 0xABCD); +} + +#[test] +fn generic_with_where_clause() { + // where clauses should be forwarded to the generated struct and impls + trait Kind { + const K: u16; + } + struct X; + impl Kind for X { + const K: u16 = 42; + } + + #[bitfield(u16, default = 0)] + pub struct Test + where + T: Kind, + { + #[bits(0..=15, rw)] + b: u16, + } + + impl Test + where + T: Kind, + { + fn kind() -> u16 { + T::K + } + } + + assert_eq!(Test::::kind(), 42); + let t: Test = Test::::builder().with_b(9).build(); + assert_eq!(t.raw_value(), 9); +} + #[test] fn arbitrary_int_base() { #[bitfield(u20, default = 0)] diff --git a/bitbybit/CHANGELOG.md b/bitbybit/CHANGELOG.md index 06b82b1..93744fe 100644 --- a/bitbybit/CHANGELOG.md +++ b/bitbybit/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## bitbybit 2.1.0 + +- Add support for generics on bitfields - they aren't by the bitfield implementation itself but can be + useful in making bitfield structs more useful + +### Added + ## bitbybit 2.0.0 This version expects arbitrary-int 2.x. diff --git a/bitbybit/src/bitfield/codegen.rs b/bitbybit/src/bitfield/codegen.rs index 6abbd91..a7ca101 100644 --- a/bitbybit/src/bitfield/codegen.rs +++ b/bitbybit/src/bitfield/codegen.rs @@ -6,7 +6,7 @@ use proc_macro2::{Ident, TokenStream}; use quote::{quote, TokenStreamExt as _}; use std::str::FromStr; use std::{collections::HashSet, ops::Range}; -use syn::{LitInt, Type, Visibility}; +use syn::{Generics, LitInt, Type, Visibility}; /// Performs the codegen for the bitfield. /// @@ -21,6 +21,7 @@ pub fn generate( base_data_size: BaseDataSize, internal_base_data_type: &Type, introspect: bool, + phantom_init: &TokenStream, ) -> Vec { let one = syn::parse_str::(format!("1u{}", base_data_size.internal).as_str()) .unwrap_or_else(|_| panic!("bitfield!: Error parsing one literal")); @@ -139,6 +140,7 @@ pub fn generate( assert!(index < #indexed_count); Self { raw_value: #new_raw_value + #phantom_init } } #(#doc_comment)* @@ -155,6 +157,7 @@ pub fn generate( pub const fn #with_name(&self, field_value: #setter_type) -> Self { Self { raw_value: #new_raw_value + #phantom_init } } #(#doc_comment)* @@ -447,9 +450,23 @@ pub fn make_builder( base_data_type: &Ident, base_data_size: BaseDataSize, field_definitions: &[FieldDefinition], + struct_generics: &Generics, + user_ty_args: &[TokenStream], ) -> (TokenStream, Vec) { let builder_struct_name = syn::parse_str::(format!("Partial{}", struct_name).as_str()).unwrap(); + let (_impl_generics, ty_generics, where_clause) = struct_generics.split_for_impl(); + let user_params = &struct_generics.params; + let build_impl_generics = if user_params.is_empty() { + quote!() + } else { + quote!(<#user_params>) + }; + let user_sep = if user_params.is_empty() { + quote!() + } else { + quote!(,) + }; let mut new_with_builder_chain: Vec = Vec::with_capacity(field_definitions.len() + 2); @@ -468,7 +485,7 @@ pub fn make_builder( /// Builder struct for partial initialization of [` #[doc = #struct_name_str] /// `]. - #struct_vis struct #builder_struct_name<#( #params, )*>(#struct_name); + #struct_vis struct #builder_struct_name<#user_params #user_sep #(#params),*>(#struct_name #ty_generics) #where_clause; }); let mut set_params: HashSet> = HashSet::default(); @@ -553,9 +570,9 @@ pub fn make_builder( let doc_comment = &field_definition.doc_comment; new_with_builder_chain.push(quote! { #[allow(non_camel_case_types)] - impl<#( #params, )*> #builder_struct_name<#( #names, )*> { + impl<#user_params #user_sep #( #params ),*> #builder_struct_name<#(#user_ty_args,)* #( #names ),*> #where_clause { #(#doc_comment)* - pub const fn #with_name(&self, __value_mangled: #argument_type) -> #builder_struct_name<#( #result, )*> { + pub const fn #with_name(&self, __value_mangled: #argument_type) -> #builder_struct_name<#(#user_ty_args,)* #( #result ),*> { #builder_struct_name(#value_transform) } } @@ -602,7 +619,7 @@ pub fn make_builder( .collect(); // All non-overlapping fields must be specified for `.build()` to be callable. new_with_builder_chain.push(quote! { - impl #builder_struct_name<#( #set_params, )*> { + impl #build_impl_generics #builder_struct_name<#(#user_ty_args,)* #( #set_params ),*> #where_clause { /// Builds the bitfield from the values passed into this builder. /// /// Every field *must* be set on [` @@ -610,7 +627,7 @@ pub fn make_builder( /// `] to be able to build a [` #[doc = #struct_name_str] /// `]. - pub const fn build(&self) -> #struct_name { + pub const fn build(&self) -> #struct_name #ty_generics { self.0 } } @@ -634,7 +651,7 @@ pub fn make_builder( let result_new_with_constructor = quote! { /// Creates a builder for this bitfield which ensures that all writable fields are /// initialized. - pub const fn builder() -> #builder_struct_name<#( #unset_params, )*> { + pub const fn builder() -> #builder_struct_name<#(#user_ty_args,)* #( #unset_params ),*> { #default } }; diff --git a/bitbybit/src/bitfield/mod.rs b/bitbybit/src/bitfield/mod.rs index 34fd54c..8ecca75 100644 --- a/bitbybit/src/bitfield/mod.rs +++ b/bitbybit/src/bitfield/mod.rs @@ -10,7 +10,7 @@ use std::ops::Range; use std::str::FromStr; use syn::meta::ParseNestedMeta; use syn::LitStr; -use syn::{parse_macro_input, Attribute, Data, DeriveInput, LitInt, Token, Type}; +use syn::{parse_macro_input, Attribute, Data, DeriveInput, GenericParam, LitInt, Token, Type}; /// In the code below, bools are considered to have 0 bits. This lets us distinguish them /// from u1 @@ -278,6 +278,49 @@ pub fn bitfield(args: TokenStream, input: TokenStream) -> TokenStream { let struct_name = &input.ident; let struct_vis = &input.vis; let struct_attrs = &input.attrs; + let struct_generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = struct_generics.split_for_impl(); + let user_ty_args: Vec = struct_generics + .params + .iter() + .map(|p| match p { + GenericParam::Type(t) => { + let i = &t.ident; + quote! { #i } + } + GenericParam::Const(c) => { + let i = &c.ident; + quote! { #i } + } + GenericParam::Lifetime(l) => { + let lt = &l.lifetime; + quote! { #lt } + } + }) + .collect(); + let phantom_types: Vec = struct_generics + .params + .iter() + .filter_map(|p| match p { + GenericParam::Type(t) => { + let i = &t.ident; + Some(quote! { fn() -> #i }) + } + GenericParam::Lifetime(l) => { + let lt = &l.lifetime; + Some(quote! { & #lt () }) + } + GenericParam::Const(_) => None, + }) + .collect(); + let (phantom_field, phantom_init) = if phantom_types.is_empty() { + (quote! {}, quote! {}) + } else { + ( + quote! { , _phantom: ::core::marker::PhantomData<( #(#phantom_types,)* )> }, + quote! { , _phantom: ::core::marker::PhantomData }, + ) + }; let fields = match &input.data { Data::Struct(struct_data) => &struct_data.fields, @@ -299,6 +342,7 @@ pub fn bitfield(args: TokenStream, input: TokenStream) -> TokenStream { base_data_size, &internal_base_data_type, bitfield_attrs.introspect, + &phantom_init, ); let (default_constructor, default_trait) = if let Some(default_value) = @@ -332,7 +376,7 @@ pub fn bitfield(args: TokenStream, input: TokenStream) -> TokenStream { }; let default_trait = quote! { - impl Default for #struct_name { + impl #impl_generics Default for #struct_name #ty_generics #where_clause { fn default() -> Self { Self::DEFAULT } @@ -366,6 +410,8 @@ pub fn bitfield(args: TokenStream, input: TokenStream) -> TokenStream { base_data_type, base_data_size, &field_definitions, + struct_generics, + &user_ty_args, ); let raw_value_unwrap = if base_data_size.exposed == base_data_size.internal { @@ -404,21 +450,28 @@ pub fn bitfield(args: TokenStream, input: TokenStream) -> TokenStream { /// No checks are performed on the value, so it is possible to set bits that don't have any /// accessors specified. #[inline] - pub const fn new_with_raw_value(value: #base_data_type) -> #struct_name { - #struct_name { + pub const fn new_with_raw_value(value: #base_data_type) -> Self { + Self { raw_value: #raw_value_unwrap + #phantom_init } } ); let expanded = quote! { - #[derive(Copy, Clone)] #[repr(C)] #( #struct_attrs )* - #struct_vis struct #struct_name { - raw_value: #internal_base_data_type, + #struct_vis struct #struct_name #struct_generics #where_clause { + raw_value: #internal_base_data_type + #phantom_field + } + + impl #impl_generics ::core::marker::Copy for #struct_name #ty_generics #where_clause {} + impl #impl_generics ::core::clone::Clone for #struct_name #ty_generics #where_clause { + #[inline] + fn clone(&self) -> Self { *self } } - impl #struct_name { + impl #impl_generics #struct_name #ty_generics #where_clause { #[doc = #zero_comment] pub const ZERO: Self = Self::new_with_raw_value(#zero);