Skip to content

Commit

Permalink
Reorganize code
Browse files Browse the repository at this point in the history
Prepare for a future PR where the macros crate needs to handle multiple different derives, not just one.
  • Loading branch information
adamchalmers committed Dec 21, 2023
1 parent f4fc48d commit 03c36e6
Show file tree
Hide file tree
Showing 9 changed files with 423 additions and 428 deletions.
374 changes: 374 additions & 0 deletions execution-plan-macros/src/derive_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,374 @@
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, quote_spanned};
use syn::{spanned::Spanned, DataEnum, DeriveInput, Fields, Ident};

use crate::helpers::remove_generics_defaults;

pub(crate) fn impl_derive_value(input: DeriveInput) -> TokenStream2 {
// Where in the input source code is this type defined?
let span = input.span();
// Name of type that is deriving Value
let name = input.ident;
// Any generics defined on the type deriving Value.
let generics = input.generics;
match input.data {
syn::Data::Struct(data) => impl_value_on_struct(span, name, data, generics),
syn::Data::Enum(data) => impl_value_on_enum(name, data, generics),
syn::Data::Union(_) => quote_spanned! {span =>
compile_error!("Value cannot be implemented on a union type")
},
}
}

fn impl_value_on_enum(
name: proc_macro2::Ident,
data: syn::DataEnum,
generics: syn::Generics,
) -> proc_macro2::TokenStream {
// First build fragments of the AST, then we'll combine them into a final output below.
// Build the arms of the `match` statements we'll use below.
let into_parts_match_each_variant = into_parts_match_arms(&data, &name);
let from_parts_match_each_variant = from_parts_match_arms(&data);
let generics_without_defaults = remove_generics_defaults(generics.clone());
let where_clause = generics.where_clause;

// Final return value: the generated Rust code to implement the trait.
// This uses the fragments above, interpolating them into the final outputted code.
quote! {
impl #generics_without_defaults kittycad_execution_plan_traits::Value for #name #generics_without_defaults
#where_clause
{
fn into_parts(self) -> Vec<kittycad_execution_plan_traits::Primitive> {
match self {
#(#into_parts_match_each_variant)*
}
}

fn from_parts<I>(values: &mut I) -> Result<Self, kittycad_execution_plan_traits::MemoryError>
where
I: Iterator<Item = Option<kittycad_execution_plan_traits::Primitive>>,
{
let variant_name = String::from_parts(values)?;
match variant_name.as_str() {
#(#from_parts_match_each_variant)*
other => Err(kittycad_execution_plan_traits::MemoryError::InvalidEnumVariant{
expected_type: stringify!(#name).to_owned(),
actual: other.to_owned(),
})
}
}
}
}
}

// Used in `from_parts()`
// This generates one match arm for each variant of the enum on which `trait Value` is being derived.
// Each match arm will call `from_parts()` recursively on each field of the enum variant,
// then reconstruct the enum from those parts.
fn from_parts_match_arms(data: &DataEnum) -> Vec<TokenStream2> {
data.variants
.iter()
.map(|variant| {
let variant_name = &variant.ident;
match &variant.fields {
// Variant with named fields, like
// ```
// enum MyEnum {
// Extrude{direction: Point3d, distance: f64},
// }
// ```
Fields::Named(expr) => {
let (field_idents, field_types): (Vec<_>, Vec<_>) = expr
.named
.iter()
.filter_map(|named| named.ident.as_ref().map(|id| (id, remove_generics(named.ty.clone()))))
.unzip();
let rhs = quote_spanned! {expr.span()=>
#(let #field_idents = #field_types::from_parts(values)?;)*
Ok(Self::#variant_name{ #(#field_idents),* })
};
quote_spanned! {variant.span() =>
stringify!(#variant_name) => {
#rhs
}
}
}
// Variant with unnamed fields (i.e. fields referenced by position, not name), like
// ```
// enum MyEnum {
// Extrude(f64),
// }
// ```
Fields::Unnamed(expr) => {
// The fields don't have built-in names, but we still need to choose identifiers
// for the variables we're going to match them into.
// Something like MyVariant(field0, field1) => {...}
let (field_idents, field_types): (Vec<_>, Vec<_>) = expr
.unnamed
.iter()
.enumerate()
.map(|(i, field)| (Ident::new(&format!("field{i}"), field.span()), &field.ty))
.unzip();
let rhs = quote_spanned! {expr.span()=>
#(let #field_idents = #field_types::from_parts(values)?;)*
Ok(Self::#variant_name(#(#field_idents),* ))
};
quote_spanned! {expr.span() =>
stringify!(#variant_name) => {
#rhs
}
}
}
// Variant with no fields (or, equivalently, where the fields are () aka the unit type), like
// ```
// enum MyEnum {
// Extrude,
// }
// ```
Fields::Unit => {
quote_spanned! {variant.span()=>
stringify!(#variant_name) => {
Ok(Self::#variant_name)
}
}
}
}
})
.collect()
}

// Used in `into_parts()`
// This generates one match arm for each variant of the enum on which `trait Value` is being derived.
// Each match arm will call `into_parts()` recursively on each field of the enum variant.
fn into_parts_match_arms(data: &DataEnum, name: &proc_macro2::Ident) -> Vec<TokenStream2> {
data.variants
.iter()
.map(|variant| {
let variant_name = &variant.ident;
let fields = &variant.fields;
let (lhs, rhs) = match fields {
// Variant with named fields, like
// ```
// enum MyEnum {
// Extrude{direction: Point3d, distance: f64},
// }
// ```
Fields::Named(expr) => {
let field_idents: Vec<_> = expr.named.iter().filter_map(|name| name.ident.as_ref()).collect();
(
quote_spanned! {expr.span()=>
#name::#variant_name{#(#field_idents),*}
},
quote_spanned! {expr.span()=>
let mut parts = Vec::new();
let tag = stringify!(#variant_name).to_owned();
parts.push(kittycad_execution_plan_traits::Primitive::from(tag));
#(parts.extend(#field_idents.into_parts());)*
parts
},
)
}
// Variant with unnamed fields (i.e. fields referenced by position, not name), like
// ```
// enum MyEnum {
// Extrude(f64),
// }
// ```
Fields::Unnamed(expr) => {
// The fields don't have built-in names, but we still need to choose identifiers
// for the variables we're going to match them into.
// Something like MyVariant(field0, field1) => {...}
let placeholder_field_idents: Vec<_> = expr
.unnamed
.iter()
.enumerate()
.map(|(i, field)| Ident::new(&format!("field{i}"), field.span()))
.collect();
(
quote_spanned! {expr.span() =>
#name::#variant_name(#(#placeholder_field_idents),*)
},
quote_spanned! {expr.span() =>
let mut parts = Vec::new();
let tag = stringify!(#variant_name).to_owned();
parts.push(kittycad_execution_plan_traits::Primitive::from(tag));
#(parts.extend(#placeholder_field_idents.into_parts());)*
parts
},
)
}
// Variant with no fields (or, equivalently, where the fields are () aka the unit type), like
// ```
// enum MyEnum {
// Extrude,
// }
// ```
Fields::Unit => (
quote_spanned! {variant.span() =>
#name::#variant_name
},
quote_spanned! {variant.span()=>
let tag = stringify!(#variant_name).to_owned();
let part = kittycad_execution_plan_traits::Primitive::from(tag);
vec![part]
},
),
};
quote_spanned! {variant.span() =>
#lhs => {
#rhs
}
}
})
.collect()
}

fn remove_generics(mut ty: syn::Type) -> syn::Type {
if let syn::Type::Path(ref mut p) = ty {
for segment in p.path.segments.iter_mut() {
if let syn::PathArguments::AngleBracketed(ref mut _a) = segment.arguments {
segment.arguments = syn::PathArguments::None;
}
}
}
ty
}

fn impl_value_on_struct(
span: Span,
name: proc_macro2::Ident,
data: syn::DataStruct,
generics: syn::Generics,
) -> proc_macro2::TokenStream {
let Fields::Named(ref fields) = data.fields else {
return quote_spanned! {span =>
compile_error!("Value cannot be implemented on a struct with unnamed fields")
};
};

// We're going to construct some fragments of Rust source code, which will get used in the
// final generated code this function returns.

// For every field in the struct, this macro will:
// - In the `into_parts`, extend the Vec of parts with that field, turned into parts.
// - In the `from_parts`, instantiate a Self with a field from that part.
// Step one is to get a list of all named fields in the struct (and their spans):
let field_names: Vec<_> = fields
.named
.iter()
.filter_map(|field| field.ident.as_ref().map(|ident| (ident, field.span())))
.collect();
// Now we can construct those `into_parts` and `from_parts` fragments.
// We take some care to use the span of each `syn::Field` as
// the span of the corresponding `into_parts()` and `from_parts()`
// calls. This way if one of the field types does not
// implement `Value` then the compiler's error message
// underlines which field it is.
let extend_per_field = field_names.iter().map(|(ident, span)| {
quote_spanned! {*span=>
parts.extend(self.#ident.into_parts());
}
});
let instantiate_each_field = field_names.iter().map(|(ident, span)| {
quote_spanned! {*span=>
#ident: kittycad_execution_plan_traits::Value::from_parts(values)?,
}
});

// Handle generics in the original struct.
// Firstly, if the original struct has defaults on its generics, e.g. Point2d<T = f32>,
// don't include those defaults in this macro's output, because the compiler
// complains it's unnecessary and will soon be a compile error.
let generics_without_defaults = remove_generics_defaults(generics.clone());
let where_clause = generics.where_clause;

// Final return value: the generated Rust code to implement the trait.
// This uses the fragments above, interpolating them into the final outputted code.
quote! {
impl #generics_without_defaults kittycad_execution_plan_traits::Value for #name #generics_without_defaults
#where_clause
{
fn into_parts(self) -> Vec<kittycad_execution_plan_traits::Primitive> {
let mut parts = Vec::new();
#(#extend_per_field)*
parts
}

fn from_parts<I>(values: &mut I) -> Result<Self, kittycad_execution_plan_traits::MemoryError>
where
I: Iterator<Item = Option<kittycad_execution_plan_traits::Primitive>>,
{
Ok(Self {
#(#instantiate_each_field)*
})
}
}
}
}

#[cfg(test)]
mod tests {
use anyhow::Result;

use super::*;

#[test]
fn test_enum() {
let input = quote! {
enum FooEnum {
A{x: usize},
B{y: usize},
C(usize, String),
D,
}
};
let input: DeriveInput = syn::parse2(input).unwrap();
let out = impl_derive_value(input);
let formatted = get_text_fmt(&out).unwrap();
insta::assert_snapshot!(formatted);
}

#[test]
fn test_enum_with_generics() {
let input = quote! {
enum Segment {
Line { point: Point3d<f64> }
}
};
let input: DeriveInput = syn::parse2(input).unwrap();
let out = impl_derive_value(input);
let formatted = get_text_fmt(&out).unwrap();
insta::assert_snapshot!(formatted);
}

#[test]
fn test_struct() {
let input = quote! {
struct Line {
point: Point3d<f64>,
tag: Option<String>,
}
};
let input: DeriveInput = syn::parse2(input).unwrap();
let out = impl_derive_value(input);
let formatted = get_text_fmt(&out).unwrap();
insta::assert_snapshot!(formatted);
}

fn clean_text(s: &str) -> String {
// Add newlines after end-braces at <= two levels of indentation.
if cfg!(not(windows)) {
let regex = regex::Regex::new(r"(})(\n\s{0,8}[^} ])").unwrap();
regex.replace_all(s, "$1\n$2").to_string()
} else {
let regex = regex::Regex::new(r"(})(\r\n\s{0,8}[^} ])").unwrap();
regex.replace_all(s, "$1\r\n$2").to_string()
}
}

/// Format a TokenStream as a string and run `rustfmt` on the result.
pub fn get_text_fmt(output: &proc_macro2::TokenStream) -> Result<String> {
let content = rustfmt_wrapper::rustfmt(output).unwrap();
Ok(clean_text(&content))
}
}
13 changes: 13 additions & 0 deletions execution-plan-macros/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/// Remove the defaults from a generic type.
/// For example, turns <T = f32> into <T>.
/// This is useful because defaults like that are valid when declaring a type, but should NOT
/// be included everywhere the type gets used.
/// E.g. you can't say `struct Foo { field: Option<T = f32> }`
pub fn remove_generics_defaults(mut g: syn::Generics) -> syn::Generics {
for generic_param in g.params.iter_mut() {
if let syn::GenericParam::Type(type_param) = generic_param {
type_param.default = None;
}
}
g
}
Loading

0 comments on commit 03c36e6

Please sign in to comment.