From 53dc5c3e133a728f260afb35ddfb34c31c6f1819 Mon Sep 17 00:00:00 2001 From: Heath Stewart Date: Thu, 27 Jun 2024 15:19:20 -0700 Subject: [PATCH] WIP: Add Variant derive macro for enums --- Cargo.toml | 3 + sdk/core/Cargo.toml | 3 +- sdk/core_macros/Cargo.toml | 26 ++++++ sdk/core_macros/src/case.rs | 110 +++++++++++++++++++++++ sdk/core_macros/src/lib.rs | 21 +++++ sdk/core_macros/src/symbol.rs | 43 +++++++++ sdk/core_macros/src/variant.rs | 158 +++++++++++++++++++++++++++++++++ sdk/identity/Cargo.toml | 3 +- 8 files changed, 363 insertions(+), 4 deletions(-) create mode 100644 sdk/core_macros/Cargo.toml create mode 100644 sdk/core_macros/src/case.rs create mode 100644 sdk/core_macros/src/lib.rs create mode 100644 sdk/core_macros/src/symbol.rs create mode 100644 sdk/core_macros/src/variant.rs diff --git a/Cargo.toml b/Cargo.toml index 6a3d7a277b..9e072e369a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,9 @@ once_cell = "1.18" openssl = { version = "0.10.46" } paste = "1.0" pin-project = "1.0" +proc-macro2 = { version = "1.0.81", default-features = false } quick-xml = { version = "0.31", features = ["serialize", "serde-types"] } +quote = { version = "1.0.36", default-features = false } rand = "0.8" reqwest = { version = "0.12", features = [ "json", @@ -46,6 +48,7 @@ serde_json = "1.0" serde_test = "1" serial_test = "3.0" sha2 = { version = "0.10" } +syn = { version = "2.0.60", default-features = false } thiserror = "1.0" time = { version = "0.3.10", features = [ "serde-well-known", diff --git a/sdk/core/Cargo.toml b/sdk/core/Cargo.toml index 279a6fa774..0be5d19b8d 100644 --- a/sdk/core/Cargo.toml +++ b/sdk/core/Cargo.toml @@ -2,13 +2,12 @@ name = "azure_core" version = "0.19.0" description = "Rust wrappers around Microsoft Azure REST APIs - Core crate" -readme = "README.md" authors.workspace = true license.workspace = true repository.workspace = true homepage = "https://github.com/azure/azure-sdk-for-rust" documentation = "https://docs.rs/azure_core" -keywords = ["sdk", "azure", "rest", "iot", "cloud"] +keywords = ["sdk", "azure", "rest", "cloud"] categories = ["api-bindings"] edition.workspace = true rust-version.workspace = true diff --git a/sdk/core_macros/Cargo.toml b/sdk/core_macros/Cargo.toml new file mode 100644 index 0000000000..e4f2a30fb7 --- /dev/null +++ b/sdk/core_macros/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "azure_core_macros" +version = "0.1.0" +description = "Rust wrappers around Microsoft Azure REST APIs - Core crate" +authors.workspace = true +license.workspace = true +repository.workspace = true +homepage = "https://github.com/azure/azure-sdk-for-rust" +documentation = "https://docs.rs/azure_core_macros" +keywords = ["sdk", "azure", "rest", "cloud"] +categories = ["api-bindings"] +edition.workspace = true +rust-version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = { workspace = true, features = ["proc-macro"] } +quote = { workspace = true, features = ["proc-macro"] } +syn = { workspace = true, features = [ + "derive", + "parsing", + "printing", + "proc-macro", +] } diff --git a/sdk/core_macros/src/case.rs b/sdk/core_macros/src/case.rs new file mode 100644 index 0000000000..eed0a6696c --- /dev/null +++ b/sdk/core_macros/src/case.rs @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::fmt; + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Case { + None, + PascalCase, + CamelCase, + SnakeCase, + Lowercase, + Uppercase, +} + +static CASES: &[(&str, Case)] = &[ + ("PascalCase", Case::PascalCase), + ("camelCase", Case::CamelCase), + ("snake_case", Case::SnakeCase), + ("lowercase", Case::Lowercase), + ("UPPERCASE", Case::Uppercase), +]; + +impl Case { + pub fn from_str<'a>(value: &'a str) -> Result> { + for (name, case) in CASES { + if value == *name { + return Ok(*case); + } + } + + Err(ParseError { case: value }) + } + + pub fn rename(self, variant: &str) -> String { + match self { + // Assumes variants are already PascalCase. + Case::None | Case::PascalCase => variant.to_owned(), + Case::CamelCase => variant[..1].to_ascii_lowercase() + &variant[1..], + Case::SnakeCase => { + let mut name = String::new(); + for (i, ch) in variant.char_indices() { + if i > 0 && ch.is_ascii_uppercase() { + name.push('_'); + } + name.push(ch.to_ascii_lowercase()); + } + name + } + Case::Lowercase => variant.to_ascii_lowercase(), + Case::Uppercase => variant.to_ascii_uppercase(), + } + } +} + +impl Default for Case { + fn default() -> Self { + Self::None + } +} + +pub struct ParseError<'a> { + case: &'a str, +} + +impl<'a> fmt::Display for ParseError<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_fmt(format_args!( + "unknown case `rename_all` = {}, expected one of ", + self.case, + ))?; + for (i, (name, ..)) in CASES.iter().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + f.write_str(name)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn from_str() { + let err = Case::from_str("other").unwrap_err(); + assert!(err.to_string().starts_with( + "unknown case `rename_all` = other, expected one of PascalCase, camelCase" + )); + } + + #[test] + fn rename_all() { + for &(original, pascal_case, camel_case, snake_case, lowercase, uppercase) in &[ + ( + "VarName", "VarName", "varName", "var_name", "varname", "VARNAME", + ), + ("Base64", "Base64", "base64", "base64", "base64", "BASE64"), + ] { + assert_eq!(Case::None.rename(original), original); + assert_eq!(Case::PascalCase.rename(original), pascal_case); + assert_eq!(Case::CamelCase.rename(original), camel_case); + assert_eq!(Case::SnakeCase.rename(original), snake_case); + assert_eq!(Case::Lowercase.rename(original), lowercase); + assert_eq!(Case::Uppercase.rename(original), uppercase); + } + } +} diff --git a/sdk/core_macros/src/lib.rs b/sdk/core_macros/src/lib.rs new file mode 100644 index 0000000000..bb24431d43 --- /dev/null +++ b/sdk/core_macros/src/lib.rs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +extern crate proc_macro2; +extern crate quote; +extern crate syn; + +use proc_macro::TokenStream; +use syn::{parse_macro_input, DeriveInput}; + +mod case; +mod symbol; +mod variant; + +#[proc_macro_derive(Variant, attributes(variant))] +pub fn variant_derive(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + variant::expand_derive_variant(&input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/sdk/core_macros/src/symbol.rs b/sdk/core_macros/src/symbol.rs new file mode 100644 index 0000000000..a0e65e0ff8 --- /dev/null +++ b/sdk/core_macros/src/symbol.rs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::fmt; + +use syn::{Ident, Path}; + +#[derive(Copy, Clone)] +pub struct Symbol(&'static str); + +pub const RENAME: Symbol = Symbol("rename"); +pub const RENAME_ALL: Symbol = Symbol("rename_all"); +pub const VARIANT: Symbol = Symbol("variant"); + +impl PartialEq for Ident { + fn eq(&self, other: &Symbol) -> bool { + *self == other.0 + } +} + +impl<'a> PartialEq for &'a Ident { + fn eq(&self, other: &Symbol) -> bool { + *self == other.0 + } +} + +impl PartialEq for Path { + fn eq(&self, other: &Symbol) -> bool { + self.is_ident(other.0) + } +} + +impl<'a> PartialEq for &'a Path { + fn eq(&self, other: &Symbol) -> bool { + self.is_ident(other.0) + } +} + +impl fmt::Display for Symbol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.0) + } +} diff --git a/sdk/core_macros/src/variant.rs b/sdk/core_macros/src/variant.rs new file mode 100644 index 0000000000..e1b2f50d5c --- /dev/null +++ b/sdk/core_macros/src/variant.rs @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use crate::{case::Case, symbol::*}; +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::{punctuated::Punctuated, spanned::Spanned as _, DeriveInput, Token}; + +pub fn expand_derive_variant(input: &DeriveInput) -> syn::Result { + let name = &input.ident; + let case = get_case(&input.attrs, Case::default())?; + let variants = get_variants(input.span(), &input.data, case)?; + + todo!(); +} + +fn get_case(attrs: &Vec, default: Case) -> syn::Result { + let mut case = default; + + for attr in attrs { + if attr.path() != VARIANT { + continue; + } + + attr.parse_nested_meta(|meta| { + if meta.path == RENAME_ALL { + let value = get_attr_string(&meta)?; + case = Case::from_str(&value) + .map_err(|err| syn::Error::new(attr.span(), err.to_string()))?; + } + + Ok(()) + })?; + } + + Ok(case) +} + +fn get_variants( + span: Span, + data: &syn::Data, + default: Case, +) -> syn::Result> { + let mut variants: Vec<_> = Vec::new(); + match *data { + syn::Data::Enum(ref data) => { + for var in &data.variants { + let mut value = default.rename(&var.ident.to_string()); + + for attr in &var.attrs { + if attr.path() != VARIANT { + continue; + } + + attr.parse_nested_meta(|meta| { + if meta.path == RENAME { + value = get_attr_string(&meta)?; + } + Ok(()) + })?; + } + + variants.push((var.ident.clone(), value)); + } + } + _ => { + return Err(syn::Error::new(span, "can only derive `Variant` on enums")); + } + } + + Ok(variants) +} + +fn get_attr_string(meta: &syn::meta::ParseNestedMeta) -> syn::Result { + let expr: syn::Expr = meta.value()?.parse()?; + let mut value: &syn::Expr = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + + let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit), + .. + }) = value + else { + return Err(syn::Error::new(expr.span(), "expected `str` literal")); + }; + + let suffix = lit.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new( + expr.span(), + format!("unexpected suffix {suffix} on `str` literal"), + )); + } + + return Ok(lit.value()); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn get_case_default() { + let tokens = quote! { + enum Sut { + Foo, + Bar, + } + }; + + let tokens = syn::parse2::(tokens).expect("variant token stream"); + let case = get_case(&tokens.attrs, Case::default()).expect("case attribute"); + + assert_eq!(Case::None, case); + } + + #[test] + fn get_case_lowercase() { + let tokens = quote! { + #[variant(rename_all = "lowercase")] + enum Sut { + Foo, + Bar, + } + }; + + let tokens = syn::parse2::(tokens).expect("variant token stream"); + let case = get_case(&tokens.attrs, Case::default()).expect("case attribute"); + + assert_eq!(Case::Lowercase, case); + } + + #[test] + fn get_variants_names() { + let tokens = quote! { + #[variant(rename_all = "lowercase")] + enum Sut { + Foo, + #[variant(rename = "BAZ")] + Bar, + } + }; + + let tokens = syn::parse2::(tokens).expect("variant token stream"); + let variants = get_variants(tokens.span(), &tokens.data, Case::Lowercase) + .expect("case attribute and variants"); + + assert_eq!(variants.len(), 2); + + assert_eq!(variants[0].0.to_string(), "Foo"); + assert_eq!(variants[0].1, "foo"); + + assert_eq!(variants[1].0.to_string(), "Bar"); + assert_eq!(variants[1].1, "BAZ"); + } +} diff --git a/sdk/identity/Cargo.toml b/sdk/identity/Cargo.toml index a0b8dbd714..675fe450cc 100644 --- a/sdk/identity/Cargo.toml +++ b/sdk/identity/Cargo.toml @@ -2,13 +2,12 @@ name = "azure_identity" version = "0.19.0" description = "Rust wrappers around Microsoft Azure REST APIs - Azure identity helper crate" -readme = "README.md" authors.workspace = true license.workspace = true repository.workspace = true homepage = "https://github.com/azure/azure-sdk-for-rust" documentation = "https://docs.rs/azure_identity" -keywords = ["sdk", "azure", "rest", "iot", "cloud"] +keywords = ["sdk", "azure", "rest", "cloud"] categories = ["api-bindings"] edition.workspace = true