diff --git a/derive/src/decode.rs b/derive/src/decode.rs index 38d5565a..fd345fa1 100644 --- a/derive/src/decode.rs +++ b/derive/src/decode.rs @@ -60,7 +60,7 @@ pub fn quote( let recurse = data_variants().enumerate().map(|(i, v)| { let name = &v.ident; - let index = utils::variant_index(v, i); + let index = utils::variant_index(v, i, &mut Vec::new()); let create = create_instance( quote! { #type_name #type_generics :: #name }, diff --git a/derive/src/encode.rs b/derive/src/encode.rs index e1ec680c..d2ed2e2f 100644 --- a/derive/src/encode.rs +++ b/derive/src/encode.rs @@ -314,9 +314,11 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS return quote!() } - let recurse = data_variants().enumerate().map(|(i, f)| { + let mut collected_const_indices: Vec = Vec::new(); + + let recurse: Vec<[TokenStream; 2]> = data_variants().enumerate().map(|(i, f)| { let name = &f.ident; - let index = utils::variant_index(f, i); + let index = utils::variant_index(f, i, &mut collected_const_indices); match f.fields { Fields::Named(ref fields) => { @@ -397,10 +399,34 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS [hinting, encoding] }, } - }); + }).collect(); + + if let Some((duplicate, token_str)) = + utils::find_const_duplicate(&collected_const_indices) + { + let error_message = + format!("index value `{}` is assigned more than once", token_str); + return syn::Error::new_spanned(&duplicate, error_message).to_compile_error(); + } - let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting); - let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding); + let recurse_hinting = recurse.iter().map(|[hinting, _]| hinting.clone()); + let recurse_encoding = recurse.iter().map(|[_, encoding]| encoding.clone()); + + // Runtime check to ensure index attribute variant is within u8 range. + let runtime_checks: Vec<_> = collected_const_indices + .iter() + .enumerate() + .map(|(idx, expr)| { + let check_const = + syn::Ident::new(&format!("CHECK_{}", idx), proc_macro2::Span::call_site()); + quote! { + const #check_const: u8 = #expr; + if #check_const as u32 > 255 { + panic!("Index attribute variant must be in 0..255, found {}", #check_const); + } + } + }) + .collect(); let hinting = quote! { // The variant index uses 1 byte. @@ -411,6 +437,8 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS }; let encoding = quote! { + #( #runtime_checks )* + match *#self_ { #( #recurse_encoding )*, _ => (), diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 5a2fa45d..9615ff5b 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -23,7 +23,7 @@ use proc_macro2::TokenStream; use quote::{ToTokens, quote}; use syn::{ parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput, - Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant, + Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant, Expr, }; fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option @@ -37,32 +37,41 @@ where }) } -/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute +/// Look for a `#[codec(index = $int)]` attribute on a variant. If no attribute /// is found, fall back to the discriminant or just the variant index. -pub fn variant_index(v: &Variant, i: usize) -> TokenStream { - // first look for an attribute - let index = find_meta_item(v.attrs.iter(), |meta| { - if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta { - if nv.path.is_ident("index") { - if let Lit::Int(ref v) = nv.lit { - let byte = v - .base10_parse::() - .expect("Internal error, index attribute must have been checked"); - return Some(byte) - } - } - } - - None - }); - - // then fallback to discriminant or just index - index.map(|i| quote! { #i }).unwrap_or_else(|| { - v.discriminant - .as_ref() - .map(|(_, expr)| quote! { #expr }) - .unwrap_or_else(|| quote! { #i }) - }) +pub fn variant_index( + variant: &Variant, + i: usize, + collected_const_indices: &mut Vec, +) -> TokenStream { + let mut index_option: Option = None; + + for attr in variant.attrs.iter().filter(|attr| attr.path.is_ident("codec")) { + if let Ok(codec_variants) = attr.parse_args::() { + if let Some(codec_index) = codec_variants.index { + match codec_index { + CodecIndex::U8(value) => { + index_option = Some(quote! { #value }); + break; + }, + CodecIndex::ExprConst(expr) => { + collected_const_indices.push(quote! { #expr }); + index_option = Some(quote! { #expr }); + break; + } + } + } + } + } + + // Fallback to discriminant or index + index_option.unwrap_or_else(|| { + variant + .discriminant + .as_ref() + .map(|(_, expr)| quote! { #expr }) + .unwrap_or_else(|| quote! { #i }) + }) } /// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given @@ -369,34 +378,75 @@ fn check_field_attribute(attr: &Attribute) -> syn::Result<()> { } } +pub enum CodecIndex { + U8(u8), + ExprConst(Expr), +} + +struct CodecVariants { + skip: bool, + index: Option, +} + +const INDEX_RANGE_ERROR: &str = "Index attribute variant must be in 0..255"; +const INDEX_TYPE_ERROR: &str = + "Only u8 indices are accepted for attribute variant `#[codec(index = $u8)]`"; +const ATTRIBUTE_ERROR: &str = + "Invalid attribute on variant, only `#[codec(skip)]` and `#[codec(index = $u8)]` are accepted."; + +impl Parse for CodecVariants { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut skip = false; + let mut index = None; + + while !input.is_empty() { + let lookahead = input.lookahead1(); + if lookahead.peek(syn::Ident) { + let ident: syn::Ident = input.parse()?; + if ident == "skip" { + skip = true; + } else if ident == "index" { + input.parse::()?; + if let Ok(lit) = input.parse::() { + let parsed_index = lit + .base10_parse::() + .map_err(|_| syn::Error::new(lit.span(), INDEX_RANGE_ERROR)); + index = Some(CodecIndex::U8(parsed_index?)); + } else { + let expr = input + .parse::() + .map_err(|_| syn::Error::new_spanned(ident, INDEX_TYPE_ERROR))?; + index = Some(CodecIndex::ExprConst(expr)); + } + } else { + return Err(syn::Error::new_spanned(ident, ATTRIBUTE_ERROR)); + } + } else { + return Err(lookahead.error()); + } + } + + Ok(CodecVariants { skip, index }) + } +} + // Ensure a field is decorated only with the following attributes: // * `#[codec(skip)]` // * `#[codec(index = $int)]` fn check_variant_attribute(attr: &Attribute) -> syn::Result<()> { - let variant_error = "Invalid attribute on variant, only `#[codec(skip)]` and \ - `#[codec(index = $u8)]` are accepted."; - if attr.path.is_ident("codec") { - match attr.parse_meta()? { - Meta::List(ref meta_list) if meta_list.nested.len() == 1 => { - match meta_list.nested.first().expect("Just checked that there is one item; qed") { - NestedMeta::Meta(Meta::Path(path)) - if path.get_ident().map_or(false, |i| i == "skip") => - Ok(()), - - NestedMeta::Meta(Meta::NameValue(MetaNameValue { - path, - lit: Lit::Int(lit_int), - .. - })) if path.get_ident().map_or(false, |i| i == "index") => lit_int - .base10_parse::() - .map(|_| ()) - .map_err(|_| syn::Error::new(lit_int.span(), "Index must be in 0..255")), - - elt => Err(syn::Error::new(elt.span(), variant_error)), + match attr.parse_args::() { + Ok(codec_variants) => { + if codec_variants.skip || codec_variants.index.is_some() { + Ok(()) + } else { + Err(syn::Error::new_spanned(attr, ATTRIBUTE_ERROR)) } }, - meta => Err(syn::Error::new(meta.span(), variant_error)), + Err(e) => Err(syn::Error::new_spanned( + attr, + format!("Error checking variant attribute: {}", e), + )), } } else { Ok(()) @@ -451,3 +501,16 @@ pub fn is_transparent(attrs: &[syn::Attribute]) -> bool { // TODO: When migrating to syn 2 the `"(transparent)"` needs to be changed into `"transparent"`. check_repr(attrs, "(transparent)") } + +/// Find a duplicate `TokenStream` in a list of indices. +/// Each `TokenStream` is a constant expression expected to represent an u8 index for an enum variant. +pub fn find_const_duplicate(indices: &[TokenStream]) -> Option<(TokenStream, String)> { + let mut seen = std::collections::HashSet::new(); + for index in indices { + let token_str = index.to_token_stream().to_string(); + if !seen.insert(token_str.clone()) { + return Some((index.clone(), token_str)); + } + } + None +} diff --git a/tests/scale_codec_ui/duplicate_const_expr.rs b/tests/scale_codec_ui/duplicate_const_expr.rs new file mode 100644 index 00000000..59bc8054 --- /dev/null +++ b/tests/scale_codec_ui/duplicate_const_expr.rs @@ -0,0 +1,12 @@ +#[derive(::parity_scale_codec::Encode)] +#[codec(crate = ::parity_scale_codec)] +pub enum Enum { + #[codec(index = MY_CONST_INDEX)] + Variant1, + #[codec(index = MY_CONST_INDEX)] + Variant2, +} + +const MY_CONST_INDEX: u8 = 1; + +fn main() {} diff --git a/tests/scale_codec_ui/duplicate_const_expr.stderr b/tests/scale_codec_ui/duplicate_const_expr.stderr new file mode 100644 index 00000000..a6c5820e --- /dev/null +++ b/tests/scale_codec_ui/duplicate_const_expr.stderr @@ -0,0 +1,5 @@ +error: index value `MY_CONST_INDEX` is assigned more than once + --> tests/scale_codec_ui/duplicate_const_expr.rs:6:21 + | +6 | #[codec(index = MY_CONST_INDEX)] + | ^^^^^^^^^^^^^^ diff --git a/tests/scale_codec_ui/invalid_attr_name.rs b/tests/scale_codec_ui/invalid_attr_name.rs new file mode 100644 index 00000000..83708195 --- /dev/null +++ b/tests/scale_codec_ui/invalid_attr_name.rs @@ -0,0 +1,8 @@ +#[derive(::parity_scale_codec::Encode)] +#[codec(crate = ::parity_scale_codec)] +pub enum Enum { + #[codec(scale = 1)] + Variant1, +} + +fn main() {} diff --git a/tests/scale_codec_ui/invalid_attr_name.stderr b/tests/scale_codec_ui/invalid_attr_name.stderr new file mode 100644 index 00000000..add5c018 --- /dev/null +++ b/tests/scale_codec_ui/invalid_attr_name.stderr @@ -0,0 +1,5 @@ +error: Error checking variant attribute: Invalid attribute on variant, only `#[codec(skip)]` and `#[codec(index = $u8)]` are accepted. + --> tests/scale_codec_ui/invalid_attr_name.rs:4:5 + | +4 | #[codec(scale = 1)] + | ^^^^^^^^^^^^^^^^^^^ diff --git a/tests/scale_codec_ui/invalid_attr_type.rs b/tests/scale_codec_ui/invalid_attr_type.rs new file mode 100644 index 00000000..5cc0a842 --- /dev/null +++ b/tests/scale_codec_ui/invalid_attr_type.rs @@ -0,0 +1,8 @@ +#[derive(::parity_scale_codec::Encode)] +#[codec(crate = ::parity_scale_codec)] +pub enum Enum { + #[codec(index = "invalid")] + Variant1, +} + +fn main() {} diff --git a/tests/scale_codec_ui/invalid_attr_type.stderr b/tests/scale_codec_ui/invalid_attr_type.stderr new file mode 100644 index 00000000..2e18c127 --- /dev/null +++ b/tests/scale_codec_ui/invalid_attr_type.stderr @@ -0,0 +1,5 @@ +error: Error checking variant attribute: Only u8 indices are accepted for attribute variant `#[codec(index = $u8)]` + --> tests/scale_codec_ui/invalid_attr_type.rs:4:5 + | +4 | #[codec(index = "invalid")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/scale_codec_ui/overflowing_index_value.rs b/tests/scale_codec_ui/overflowing_index_value.rs new file mode 100644 index 00000000..2d330a3d --- /dev/null +++ b/tests/scale_codec_ui/overflowing_index_value.rs @@ -0,0 +1,8 @@ +#[derive(::parity_scale_codec::Encode)] +#[codec(crate = ::parity_scale_codec)] +pub enum Enum { + #[codec(index = 256)] + Variant1, +} + +fn main() {} diff --git a/tests/scale_codec_ui/overflowing_index_value.stderr b/tests/scale_codec_ui/overflowing_index_value.stderr new file mode 100644 index 00000000..d7c92dba --- /dev/null +++ b/tests/scale_codec_ui/overflowing_index_value.stderr @@ -0,0 +1,5 @@ +error: Error checking variant attribute: Index attribute variant must be in 0..255 + --> tests/scale_codec_ui/overflowing_index_value.rs:4:5 + | +4 | #[codec(index = 256)] + | ^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/variant_number.rs b/tests/variant_number.rs index 54a900d3..80fe6bc1 100644 --- a/tests/variant_number.rs +++ b/tests/variant_number.rs @@ -38,3 +38,38 @@ fn index_attr_variant_counted_and_reused_in_default_index() { assert_eq!(T::A.encode(), vec![1]); assert_eq!(T::B.encode(), vec![1]); } + +#[test] +fn different_const_expr_in_index_attr_variant() { + const MY_CONST_INDEX: u8 = 1; + const ANOTHER_CONST_INDEX: u8 = 2; + + #[derive(DeriveEncode)] + enum T { + #[codec(index = MY_CONST_INDEX)] + A, + B, + #[codec(index = ANOTHER_CONST_INDEX)] + C, + #[codec(index = 3)] + D, + } + + assert_eq!(T::A.encode(), vec![1]); + assert_eq!(T::B.encode(), vec![1]); + assert_eq!(T::C.encode(), vec![2]); + assert_eq!(T::D.encode(), vec![3]); +} + +#[test] +fn complex_const_expr_in_index_attr_variant() { + const MY_CONST_INDEX: u8 = 1; + + #[derive(DeriveEncode)] + enum T { + #[codec(index = MY_CONST_INDEX + 1_u8)] + A, + } + + assert_eq!(T::A.encode(), vec![2]); +}