diff --git a/CHANGELOG.md b/CHANGELOG.md index e2ed748..f6c30d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,20 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.0] - 2026-02-24 + +### Added +- **Complex Enum Variant Support**: Full support for Rust enums with tuple and struct variants + - Simple enums (unit variants only) continue to generate TypeScript string literal unions / `z.enum()` + - Complex enums (tuple or struct variants) now generate TypeScript discriminated unions / `z.discriminatedUnion()` + - Variant payloads are fully typed, including nested structs and generics + - Enum variants are included in the AST cache for efficient subsequent runs + +### Changed +- **Serde Attribute Parsing**: Replaced regex-based parsing with `syn` AST functions + - More robust and accurate handling of `#[serde(...)]` attributes + - Eliminates edge cases caused by pattern matching on raw token strings + ## [0.4.2] - 2026-02-15 ### Fixed diff --git a/Cargo.toml b/Cargo.toml index 74841bc..018e715 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tauri-typegen" -version = "0.4.2" +version = "0.5.0" authors = [ "Stefan Poindl" ] description = "A rust crate that automatically generates TypeScript models and bindings from your Tauri commands" edition = "2021" diff --git a/src/analysis/dependency_graph.rs b/src/analysis/dependency_graph.rs index 34103af..7f93a8f 100644 --- a/src/analysis/dependency_graph.rs +++ b/src/analysis/dependency_graph.rs @@ -251,6 +251,8 @@ mod tests { file_path: file.to_string(), is_enum: false, serde_rename_all: None, + serde_tag: None, + enum_variants: None, } } diff --git a/src/analysis/mod.rs b/src/analysis/mod.rs index 1964c1d..2f81286 100644 --- a/src/analysis/mod.rs +++ b/src/analysis/mod.rs @@ -348,7 +348,11 @@ impl CommandAnalyzer { if item_enum.ident == type_name && self.struct_parser.should_include_enum(item_enum) { - return self.struct_parser.parse_enum(item_enum, file_path); + return self.struct_parser.parse_enum( + item_enum, + file_path, + &mut self.type_resolver, + ); } } _ => {} diff --git a/src/analysis/serde_parser.rs b/src/analysis/serde_parser.rs index 6244ddc..c585a17 100644 --- a/src/analysis/serde_parser.rs +++ b/src/analysis/serde_parser.rs @@ -1,6 +1,5 @@ -use quote::ToTokens; use serde_rename_rule::RenameRule; -use syn::Attribute; +use syn::{Attribute, Expr, ExprLit, Lit}; /// Parser for serde attributes from Rust struct/enum definitions and fields #[derive(Debug)] @@ -11,20 +10,28 @@ impl SerdeParser { Self } - /// Parse struct-level serde attributes (e.g., rename_all) + /// Parse struct-level serde attributes (e.g., rename_all, tag, content) pub fn parse_struct_serde_attrs(&self, attrs: &[Attribute]) -> SerdeStructAttributes { - let mut result = SerdeStructAttributes { rename_all: None }; + let mut result = SerdeStructAttributes { + rename_all: None, + tag: None, + content: None, + }; for attr in attrs { if attr.path().is_ident("serde") { - if let Ok(tokens) = syn::parse2::(attr.meta.to_token_stream()) { - let tokens_str = tokens.tokens.to_string(); - - // Parse rename_all = "convention" - if let Some(convention) = self.parse_rename_all(&tokens_str) { - result.rename_all = Some(convention); + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("rename_all") { + if let Some(value) = parse_string_value(&meta)? { + result.rename_all = RenameRule::from_rename_all_str(&value).ok(); + } + } else if meta.path.is_ident("tag") { + result.tag = parse_string_value(&meta)?; + } else if meta.path.is_ident("content") { + result.content = parse_string_value(&meta)?; } - } + Ok(()) + }); } } @@ -40,76 +47,35 @@ impl SerdeParser { for attr in attrs { if attr.path().is_ident("serde") { - if let Ok(tokens) = syn::parse2::(attr.meta.to_token_stream()) { - let tokens_str = tokens.tokens.to_string(); - - // Check for skip flag - if tokens_str.contains("skip") && !tokens_str.contains("skip_serializing") { + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("rename") { + result.rename = parse_string_value(&meta)?; + } else if meta.path.is_ident("skip") { + // skip is a flag, no value needed result.skip = true; } - - // Parse rename = "value" - if let Some(rename) = self.parse_rename(&tokens_str) { - result.rename = Some(rename); - } - } + // Note: skip_serializing and skip_deserializing are NOT the same as skip + // They only affect one direction, so we don't set the skip flag for them + Ok(()) + }); } } result } +} - /// Parse rename_all value like "camelCase", "snake_case", "PascalCase", etc. to - /// find a matching `serde_rename_rule::RenameRule`. - fn parse_rename_all(&self, tokens: &str) -> Option { - if let Some(start) = tokens.find("rename_all") { - if let Some(eq_pos) = tokens[start..].find('=') { - let after_eq = &tokens[start + eq_pos + 1..].trim_start(); - - // Extract value from quotes - if let Some(quote_start) = after_eq.find('"') { - if let Some(quote_end) = after_eq[quote_start + 1..].find('"') { - let value = &after_eq[quote_start + 1..quote_start + 1 + quote_end]; - - return RenameRule::from_rename_all_str(value).ok(); - } - } - } - } - None - } - - /// Parse rename value from field attribute - fn parse_rename(&self, tokens: &str) -> Option { - // Look for "rename" but not "rename_all" - let mut search_start = 0; - while let Some(pos) = tokens[search_start..].find("rename") { - let abs_pos = search_start + pos; - - // Check if this is followed by "_all" - let after_rename = &tokens[abs_pos + 6..]; - if after_rename.trim_start().starts_with("_all") { - // This is rename_all, skip it - search_start = abs_pos + 10; // Move past "rename_all" - continue; - } - - // This is a plain "rename", extract the value - if let Some(eq_pos) = after_rename.find('=') { - let after_eq = &after_rename[eq_pos + 1..].trim_start(); - - // Extract value from quotes - if let Some(quote_start) = after_eq.find('"') { - if let Some(quote_end) = after_eq[quote_start + 1..].find('"') { - let value = &after_eq[quote_start + 1..quote_start + 1 + quote_end]; - return Some(value.to_string()); - } - } - } - - break; - } - None +/// Parse a string value from a meta item like `name = "value"` +fn parse_string_value(meta: &syn::meta::ParseNestedMeta) -> syn::Result> { + let expr: Expr = meta.value()?.parse()?; + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = expr + { + Ok(Some(lit_str.value())) + } else { + Ok(None) } } @@ -123,6 +89,10 @@ impl Default for SerdeParser { #[derive(Debug, Default, Clone)] pub struct SerdeStructAttributes { pub rename_all: Option, + /// Tag attribute for internally-tagged enum representation: #[serde(tag = "type")] + pub tag: Option, + /// Content attribute for adjacently-tagged enum representation: #[serde(content = "data")] + pub content: Option, } /// Field-level serde attributes @@ -131,110 +101,124 @@ pub struct SerdeFieldAttributes { pub rename: Option, pub skip: bool, } + #[cfg(test)] mod tests { use super::*; use syn::parse_quote; #[test] - fn test_parse_rename_all_camel_case() { + fn test_parse_struct_serde_attrs_with_rename_all_camel_case() { let parser = SerdeParser::new(); - let result = parser.parse_rename_all(r#"rename_all = "camelCase""#); + let attrs: Vec = vec![parse_quote!(#[serde(rename_all = "camelCase")])]; - assert!(result.is_some()); - assert!(matches!(result.unwrap(), RenameRule::CamelCase)); + let result = parser.parse_struct_serde_attrs(&attrs); + assert!(result.rename_all.is_some()); + assert!(matches!(result.rename_all.unwrap(), RenameRule::CamelCase)); } #[test] - fn test_parse_rename_all_snake_case() { + fn test_parse_struct_serde_attrs_with_rename_all_snake_case() { let parser = SerdeParser::new(); - let result = parser.parse_rename_all(r#"rename_all = "snake_case""#); + let attrs: Vec = vec![parse_quote!(#[serde(rename_all = "snake_case")])]; - assert!(result.is_some()); - assert!(matches!(result.unwrap(), RenameRule::SnakeCase)); + let result = parser.parse_struct_serde_attrs(&attrs); + assert!(result.rename_all.is_some()); + assert!(matches!(result.rename_all.unwrap(), RenameRule::SnakeCase)); } #[test] - fn test_parse_rename_all_pascal_case() { + fn test_parse_struct_serde_attrs_with_rename_all_pascal_case() { let parser = SerdeParser::new(); - let result = parser.parse_rename_all(r#"rename_all = "PascalCase""#); + let attrs: Vec = vec![parse_quote!(#[serde(rename_all = "PascalCase")])]; - assert!(result.is_some()); - assert!(matches!(result.unwrap(), RenameRule::PascalCase)); + let result = parser.parse_struct_serde_attrs(&attrs); + assert!(result.rename_all.is_some()); + assert!(matches!(result.rename_all.unwrap(), RenameRule::PascalCase)); } #[test] - fn test_parse_rename_all_screaming_snake_case() { + fn test_parse_struct_serde_attrs_with_rename_all_screaming_snake_case() { let parser = SerdeParser::new(); - let result = parser.parse_rename_all(r#"rename_all = "SCREAMING_SNAKE_CASE""#); + let attrs: Vec = + vec![parse_quote!(#[serde(rename_all = "SCREAMING_SNAKE_CASE")])]; - assert!(result.is_some()); - assert!(matches!(result.unwrap(), RenameRule::ScreamingSnakeCase)); + let result = parser.parse_struct_serde_attrs(&attrs); + assert!(result.rename_all.is_some()); + assert!(matches!( + result.rename_all.unwrap(), + RenameRule::ScreamingSnakeCase + )); } #[test] - fn test_parse_rename_all_kebab_case() { + fn test_parse_struct_serde_attrs_with_rename_all_kebab_case() { let parser = SerdeParser::new(); - let result = parser.parse_rename_all(r#"rename_all = "kebab-case""#); + let attrs: Vec = vec![parse_quote!(#[serde(rename_all = "kebab-case")])]; - assert!(result.is_some()); - assert!(matches!(result.unwrap(), RenameRule::KebabCase)); + let result = parser.parse_struct_serde_attrs(&attrs); + assert!(result.rename_all.is_some()); + assert!(matches!(result.rename_all.unwrap(), RenameRule::KebabCase)); } #[test] - fn test_parse_rename_all_not_present() { + fn test_parse_struct_serde_attrs_no_serde() { let parser = SerdeParser::new(); - let result = parser.parse_rename_all(r#"skip_serializing_if = "Option::is_none""#); + let attrs: Vec = vec![parse_quote!(#[derive(Debug)])]; - assert!(result.is_none()); + let result = parser.parse_struct_serde_attrs(&attrs); + assert!(result.rename_all.is_none()); + assert!(result.tag.is_none()); + assert!(result.content.is_none()); } #[test] - fn test_parse_rename() { + fn test_parse_struct_serde_attrs_with_tag() { let parser = SerdeParser::new(); + let attrs: Vec = vec![parse_quote!(#[serde(tag = "type")])]; - let result = parser.parse_rename(r#"rename = "customName""#); - assert_eq!(result, Some("customName".to_string())); - - let result = parser.parse_rename(r#"rename = "id""#); - assert_eq!(result, Some("id".to_string())); + let result = parser.parse_struct_serde_attrs(&attrs); + assert_eq!(result.tag, Some("type".to_string())); } #[test] - fn test_parse_rename_not_rename_all() { + fn test_parse_struct_serde_attrs_with_custom_tag() { let parser = SerdeParser::new(); + let attrs: Vec = vec![parse_quote!(#[serde(tag = "kind")])]; - // Should not match rename_all - let result = parser.parse_rename(r#"rename_all = "camelCase""#); - assert!(result.is_none()); + let result = parser.parse_struct_serde_attrs(&attrs); + assert_eq!(result.tag, Some("kind".to_string())); } #[test] - fn test_parse_rename_with_rename_all_present() { + fn test_parse_struct_serde_attrs_with_content() { let parser = SerdeParser::new(); + let attrs: Vec = vec![parse_quote!(#[serde(content = "data")])]; - // Should find "rename" even if rename_all is also present - let result = parser.parse_rename(r#"rename_all = "camelCase", rename = "id""#); - assert_eq!(result, Some("id".to_string())); + let result = parser.parse_struct_serde_attrs(&attrs); + assert_eq!(result.content, Some("data".to_string())); } #[test] - fn test_parse_struct_serde_attrs_with_rename_all() { + fn test_parse_struct_serde_attrs_with_tag_and_content() { let parser = SerdeParser::new(); - let attrs: Vec = vec![parse_quote!(#[serde(rename_all = "camelCase")])]; + let attrs: Vec = vec![parse_quote!(#[serde(tag = "kind", content = "data")])]; let result = parser.parse_struct_serde_attrs(&attrs); - assert!(result.rename_all.is_some()); - assert!(matches!(result.rename_all.unwrap(), RenameRule::CamelCase)); + assert_eq!(result.tag, Some("kind".to_string())); + assert_eq!(result.content, Some("data".to_string())); } #[test] - fn test_parse_struct_serde_attrs_no_serde() { + fn test_parse_struct_serde_attrs_with_all_attributes() { let parser = SerdeParser::new(); - let attrs: Vec = vec![parse_quote!(#[derive(Debug)])]; + let attrs: Vec = + vec![parse_quote!(#[serde(rename_all = "camelCase", tag = "type", content = "value")])]; let result = parser.parse_struct_serde_attrs(&attrs); - assert!(result.rename_all.is_none()); + assert!(matches!(result.rename_all, Some(RenameRule::CamelCase))); + assert_eq!(result.tag, Some("type".to_string())); + assert_eq!(result.content, Some("value".to_string())); } #[test] @@ -268,7 +252,17 @@ mod tests { } #[test] - fn test_parse_field_serde_attrs_multiple() { + fn test_parse_field_serde_attrs_skip_deserializing_not_skip() { + let parser = SerdeParser::new(); + let attrs: Vec = vec![parse_quote!(#[serde(skip_deserializing)])]; + + let result = parser.parse_field_serde_attrs(&attrs); + // skip_deserializing should not set skip flag + assert!(!result.skip); + } + + #[test] + fn test_parse_field_serde_attrs_multiple_attributes() { let parser = SerdeParser::new(); let attrs: Vec = vec![ parse_quote!(#[serde(rename = "id")]), @@ -279,6 +273,16 @@ mod tests { assert_eq!(result.rename, Some("id".to_string())); } + #[test] + fn test_parse_field_serde_attrs_rename_and_skip() { + let parser = SerdeParser::new(); + let attrs: Vec = vec![parse_quote!(#[serde(rename = "id", skip)])]; + + let result = parser.parse_field_serde_attrs(&attrs); + assert_eq!(result.rename, Some("id".to_string())); + assert!(result.skip); + } + #[test] fn test_parse_field_serde_attrs_no_serde() { let parser = SerdeParser::new(); @@ -289,10 +293,58 @@ mod tests { assert!(!result.skip); } + #[test] + fn test_parse_field_serde_attrs_empty() { + let parser = SerdeParser::new(); + let attrs: Vec = vec![]; + + let result = parser.parse_field_serde_attrs(&attrs); + assert!(result.rename.is_none()); + assert!(!result.skip); + } + #[test] fn test_default_impl() { - let parser = SerdeParser; - let result = parser.parse_rename(r#"rename = "test""#); - assert_eq!(result, Some("test".to_string())); + let parser = SerdeParser::new(); + let attrs: Vec = vec![parse_quote!(#[serde(rename = "test")])]; + let result = parser.parse_field_serde_attrs(&attrs); + assert_eq!(result.rename, Some("test".to_string())); + } + + #[test] + fn test_parse_struct_serde_attrs_ignores_other_attributes() { + let parser = SerdeParser::new(); + let attrs: Vec = vec![parse_quote!( + #[serde(rename_all = "camelCase", deny_unknown_fields)] + )]; + + let result = parser.parse_struct_serde_attrs(&attrs); + assert!(matches!(result.rename_all, Some(RenameRule::CamelCase))); + // deny_unknown_fields is ignored, no error + } + + #[test] + fn test_parse_field_serde_attrs_ignores_other_attributes() { + let parser = SerdeParser::new(); + let attrs: Vec = vec![ + parse_quote!(#[serde(rename = "id", default, skip_serializing_if = "Option::is_none")]), + ]; + + let result = parser.parse_field_serde_attrs(&attrs); + assert_eq!(result.rename, Some("id".to_string())); + // default and skip_serializing_if are ignored, no error + } + + #[test] + fn test_parse_multiple_serde_attributes() { + let parser = SerdeParser::new(); + let attrs: Vec = vec![ + parse_quote!(#[serde(rename_all = "camelCase")]), + parse_quote!(#[serde(tag = "type")]), + ]; + + let result = parser.parse_struct_serde_attrs(&attrs); + assert!(matches!(result.rename_all, Some(RenameRule::CamelCase))); + assert_eq!(result.tag, Some("type".to_string())); } } diff --git a/src/analysis/struct_parser.rs b/src/analysis/struct_parser.rs index a1e6f59..7440431 100644 --- a/src/analysis/struct_parser.rs +++ b/src/analysis/struct_parser.rs @@ -1,7 +1,7 @@ use crate::analysis::serde_parser::SerdeParser; use crate::analysis::type_resolver::TypeResolver; use crate::analysis::validator_parser::ValidatorParser; -use crate::models::{FieldInfo, StructInfo}; +use crate::models::{EnumVariantInfo, EnumVariantKind, FieldInfo, StructInfo, TypeStructure}; use quote::ToTokens; use std::path::Path; use syn::{Attribute, ItemEnum, ItemStruct, Type, Visibility}; @@ -92,73 +92,103 @@ impl StructParser { file_path: file_path.to_string_lossy().to_string(), is_enum: false, serde_rename_all: struct_serde_attrs.rename_all, + serde_tag: None, + enum_variants: None, }) } /// Parse a Rust enum into StructInfo - pub fn parse_enum(&self, item_enum: &ItemEnum, file_path: &Path) -> Option { + pub fn parse_enum( + &self, + item_enum: &ItemEnum, + file_path: &Path, + type_resolver: &mut TypeResolver, + ) -> Option { // Parse enum-level serde attributes let enum_serde_attrs = self.serde_parser.parse_struct_serde_attrs(&item_enum.attrs); - let fields = item_enum - .variants - .iter() - .map(|variant| { - let variant_name = variant.ident.to_string(); - - // Parse variant-level serde attributes - let variant_serde_attrs = self.serde_parser.parse_field_serde_attrs(&variant.attrs); - - match &variant.fields { - syn::Fields::Unit => { - // Unit variant: Variant - FieldInfo { - name: variant_name, - rust_type: "enum_variant".to_string(), - is_optional: false, - is_public: true, - validator_attributes: None, - serde_rename: variant_serde_attrs.rename, - type_structure: crate::models::TypeStructure::Primitive( - "string".to_string(), - ), - } - } - syn::Fields::Unnamed(_fields_unnamed) => { - // Tuple variant: Variant(T, U) - // Note: Complex enum variants are not fully supported yet - FieldInfo { - name: variant_name, - rust_type: "enum_variant_tuple".to_string(), - is_optional: false, - is_public: true, - validator_attributes: None, - serde_rename: variant_serde_attrs.rename, - // For enum variants, type structure is not used by generators - type_structure: crate::models::TypeStructure::Custom( - "enum_variant".to_string(), - ), - } - } - syn::Fields::Named(_fields_named) => { - // Struct variant: Variant { field: T } - // Note: Complex enum variants are not fully supported yet - FieldInfo { - name: variant_name, - rust_type: "enum_variant_struct".to_string(), - is_optional: false, - is_public: true, - validator_attributes: None, - serde_rename: variant_serde_attrs.rename, - // For enum variants, type structure is not used by generators - type_structure: crate::models::TypeStructure::Custom( - "enum_variant".to_string(), - ), - } - } + // Parse variants into both legacy fields (for backward compatibility) and new enum_variants + let mut fields = Vec::new(); + let mut enum_variants = Vec::new(); + + for variant in &item_enum.variants { + let variant_name = variant.ident.to_string(); + + // Parse variant-level serde attributes + let variant_serde_attrs = self.serde_parser.parse_field_serde_attrs(&variant.attrs); + + match &variant.fields { + syn::Fields::Unit => { + // Unit variant: Variant + fields.push(FieldInfo { + name: variant_name.clone(), + rust_type: "enum_variant".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: variant_serde_attrs.rename.clone(), + type_structure: TypeStructure::Primitive("string".to_string()), + }); + + enum_variants.push(EnumVariantInfo { + name: variant_name, + kind: EnumVariantKind::Unit, + serde_rename: variant_serde_attrs.rename, + }); + } + syn::Fields::Unnamed(fields_unnamed) => { + // Tuple variant: Variant(T, U) + let tuple_types: Vec = fields_unnamed + .unnamed + .iter() + .map(|field| { + let rust_type = Self::type_to_string(&field.ty); + type_resolver.parse_type_structure(&rust_type) + }) + .collect(); + + fields.push(FieldInfo { + name: variant_name.clone(), + rust_type: "enum_variant_tuple".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: variant_serde_attrs.rename.clone(), + type_structure: TypeStructure::Custom("enum_variant".to_string()), + }); + + enum_variants.push(EnumVariantInfo { + name: variant_name, + kind: EnumVariantKind::Tuple(tuple_types), + serde_rename: variant_serde_attrs.rename, + }); } - }) - .collect(); + syn::Fields::Named(fields_named) => { + // Struct variant: Variant { field: T } + let struct_fields: Vec = fields_named + .named + .iter() + .filter_map(|field| self.parse_field(field, type_resolver)) + .collect(); + + fields.push(FieldInfo { + name: variant_name.clone(), + rust_type: "enum_variant_struct".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: variant_serde_attrs.rename.clone(), + type_structure: TypeStructure::Custom("enum_variant".to_string()), + }); + + enum_variants.push(EnumVariantInfo { + name: variant_name, + kind: EnumVariantKind::Struct(struct_fields), + serde_rename: variant_serde_attrs.rename, + }); + } + } + } Some(StructInfo { name: item_enum.ident.to_string(), @@ -166,6 +196,8 @@ impl StructParser { file_path: file_path.to_string_lossy().to_string(), is_enum: true, serde_rename_all: enum_serde_attrs.rename_all, + serde_tag: enum_serde_attrs.tag, + enum_variants: Some(enum_variants), }) } @@ -519,6 +551,7 @@ mod tests { #[test] fn test_parse_simple_enum() { let parser = parser(); + let mut resolver = type_resolver(); let item: ItemEnum = parse_quote! { #[derive(Serialize)] pub enum Status { @@ -527,7 +560,7 @@ mod tests { } }; let path = Path::new("test.rs"); - let result = parser.parse_enum(&item, path); + let result = parser.parse_enum(&item, path, &mut resolver); assert!(result.is_some()); let enum_info = result.unwrap(); @@ -539,6 +572,7 @@ mod tests { #[test] fn test_parse_enum_unit_variants() { let parser = parser(); + let mut resolver = type_resolver(); let item: ItemEnum = parse_quote! { #[derive(Serialize)] pub enum Status { @@ -548,18 +582,26 @@ mod tests { } }; let path = Path::new("test.rs"); - let result = parser.parse_enum(&item, path).unwrap(); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); assert_eq!(result.fields.len(), 3); assert_eq!(result.fields[0].name, "Active"); assert_eq!(result.fields[0].rust_type, "enum_variant"); assert_eq!(result.fields[1].name, "Inactive"); assert_eq!(result.fields[2].name, "Pending"); + + // Check enum_variants are populated + let variants = result.enum_variants.as_ref().unwrap(); + assert_eq!(variants.len(), 3); + assert!(variants[0].is_unit()); + assert!(variants[1].is_unit()); + assert!(variants[2].is_unit()); } #[test] fn test_parse_enum_tuple_variant() { let parser = parser(); + let mut resolver = type_resolver(); let item: ItemEnum = parse_quote! { #[derive(Serialize)] pub enum Message { @@ -568,16 +610,71 @@ mod tests { } }; let path = Path::new("test.rs"); - let result = parser.parse_enum(&item, path).unwrap(); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); assert_eq!(result.fields.len(), 2); assert_eq!(result.fields[0].rust_type, "enum_variant_tuple"); assert_eq!(result.fields[1].rust_type, "enum_variant_tuple"); + + // Check enum_variants with tuple types + let variants = result.enum_variants.as_ref().unwrap(); + assert_eq!(variants.len(), 2); + assert!(variants[0].is_tuple()); + assert!(variants[1].is_tuple()); + + // Check tuple field types + let text_fields = variants[0].tuple_fields().unwrap(); + assert_eq!(text_fields.len(), 1); + assert_eq!( + text_fields[0], + crate::models::TypeStructure::Primitive("string".to_string()) + ); + + let number_fields = variants[1].tuple_fields().unwrap(); + assert_eq!(number_fields.len(), 1); + assert_eq!( + number_fields[0], + crate::models::TypeStructure::Primitive("number".to_string()) + ); + } + + #[test] + fn test_parse_enum_tuple_variant_multiple_fields() { + let parser = parser(); + let mut resolver = type_resolver(); + let item: ItemEnum = parse_quote! { + #[derive(Serialize)] + pub enum Message { + Move(i32, i32), + Point(f64, f64, f64), + } + }; + let path = Path::new("test.rs"); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); + + let variants = result.enum_variants.as_ref().unwrap(); + + // Check Move variant has 2 number fields + let move_fields = variants[0].tuple_fields().unwrap(); + assert_eq!(move_fields.len(), 2); + assert_eq!( + move_fields[0], + crate::models::TypeStructure::Primitive("number".to_string()) + ); + assert_eq!( + move_fields[1], + crate::models::TypeStructure::Primitive("number".to_string()) + ); + + // Check Point variant has 3 number fields + let point_fields = variants[1].tuple_fields().unwrap(); + assert_eq!(point_fields.len(), 3); } #[test] fn test_parse_enum_struct_variant() { let parser = parser(); + let mut resolver = type_resolver(); let item: ItemEnum = parse_quote! { #[derive(Serialize)] pub enum Message { @@ -585,15 +682,28 @@ mod tests { } }; let path = Path::new("test.rs"); - let result = parser.parse_enum(&item, path).unwrap(); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); assert_eq!(result.fields.len(), 1); assert_eq!(result.fields[0].rust_type, "enum_variant_struct"); + + // Check enum_variants with struct fields + let variants = result.enum_variants.as_ref().unwrap(); + assert_eq!(variants.len(), 1); + assert!(variants[0].is_struct()); + + let struct_fields = variants[0].struct_fields().unwrap(); + assert_eq!(struct_fields.len(), 2); + assert_eq!(struct_fields[0].name, "id"); + assert_eq!(struct_fields[0].rust_type, "i32"); + assert_eq!(struct_fields[1].name, "name"); + assert_eq!(struct_fields[1].rust_type, "String"); } #[test] fn test_parse_enum_with_serde_rename_variant() { let parser = parser(); + let mut resolver = type_resolver(); let item: ItemEnum = parse_quote! { #[derive(Serialize)] pub enum Status { @@ -604,15 +714,21 @@ mod tests { } }; let path = Path::new("test.rs"); - let result = parser.parse_enum(&item, path).unwrap(); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); assert_eq!(result.fields[0].serde_rename, Some("active".to_string())); assert_eq!(result.fields[1].serde_rename, Some("inactive".to_string())); + + // Check enum_variants also have serde_rename + let variants = result.enum_variants.as_ref().unwrap(); + assert_eq!(variants[0].serde_rename, Some("active".to_string())); + assert_eq!(variants[1].serde_rename, Some("inactive".to_string())); } #[test] fn test_parse_enum_with_rename_all() { let parser = parser(); + let mut resolver = type_resolver(); let item: ItemEnum = parse_quote! { #[derive(Serialize)] #[serde(rename_all = "snake_case")] @@ -622,10 +738,163 @@ mod tests { } }; let path = Path::new("test.rs"); - let result = parser.parse_enum(&item, path).unwrap(); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); assert_eq!(result.serde_rename_all, Some(RenameRule::SnakeCase)); } + + #[test] + fn test_parse_enum_with_serde_tag() { + let parser = parser(); + let mut resolver = type_resolver(); + let item: ItemEnum = parse_quote! { + #[derive(Serialize)] + #[serde(tag = "type")] + pub enum Message { + Text(String), + Quit, + } + }; + let path = Path::new("test.rs"); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); + + assert_eq!(result.serde_tag, Some("type".to_string())); + } + + #[test] + fn test_parse_enum_with_custom_tag() { + let parser = parser(); + let mut resolver = type_resolver(); + let item: ItemEnum = parse_quote! { + #[derive(Serialize)] + #[serde(tag = "kind")] + pub enum Action { + Start, + Stop, + } + }; + let path = Path::new("test.rs"); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); + + assert_eq!(result.serde_tag, Some("kind".to_string())); + assert_eq!(result.discriminator_tag(), "kind"); + } + + #[test] + fn test_parse_enum_mixed_variants() { + let parser = parser(); + let mut resolver = type_resolver(); + let item: ItemEnum = parse_quote! { + #[derive(Serialize)] + #[serde(tag = "type")] + pub enum Message { + Quit, + Move(i32, i32), + Write(String), + ChangeColor { r: u8, g: u8, b: u8 }, + } + }; + let path = Path::new("test.rs"); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); + + assert_eq!(result.serde_tag, Some("type".to_string())); + + let variants = result.enum_variants.as_ref().unwrap(); + assert_eq!(variants.len(), 4); + + // Quit is unit + assert!(variants[0].is_unit()); + assert_eq!(variants[0].name, "Quit"); + + // Move is tuple with 2 fields + assert!(variants[1].is_tuple()); + assert_eq!(variants[1].name, "Move"); + assert_eq!(variants[1].tuple_fields().unwrap().len(), 2); + + // Write is tuple with 1 field + assert!(variants[2].is_tuple()); + assert_eq!(variants[2].name, "Write"); + assert_eq!(variants[2].tuple_fields().unwrap().len(), 1); + + // ChangeColor is struct with 3 fields + assert!(variants[3].is_struct()); + assert_eq!(variants[3].name, "ChangeColor"); + let struct_fields = variants[3].struct_fields().unwrap(); + assert_eq!(struct_fields.len(), 3); + assert_eq!(struct_fields[0].name, "r"); + assert_eq!(struct_fields[1].name, "g"); + assert_eq!(struct_fields[2].name, "b"); + } + + #[test] + fn test_parse_enum_is_simple_vs_complex() { + let parser = parser(); + let mut resolver = type_resolver(); + + // Simple enum (all unit variants) + let simple_item: ItemEnum = parse_quote! { + #[derive(Serialize)] + pub enum Status { + Active, + Inactive, + } + }; + let path = Path::new("test.rs"); + let simple_result = parser + .parse_enum(&simple_item, path, &mut resolver) + .unwrap(); + assert!(simple_result.is_simple_enum()); + assert!(!simple_result.is_complex_enum()); + + // Complex enum (has tuple variant) + let complex_item: ItemEnum = parse_quote! { + #[derive(Serialize)] + pub enum Message { + Quit, + Text(String), + } + }; + let complex_result = parser + .parse_enum(&complex_item, path, &mut resolver) + .unwrap(); + assert!(!complex_result.is_simple_enum()); + assert!(complex_result.is_complex_enum()); + } + + #[test] + fn test_parse_enum_with_nested_types() { + let parser = parser(); + let mut resolver = type_resolver(); + let item: ItemEnum = parse_quote! { + #[derive(Serialize)] + pub enum Data { + List(Vec), + Map { items: HashMap }, + } + }; + let path = Path::new("test.rs"); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); + + let variants = result.enum_variants.as_ref().unwrap(); + + // Check List variant has Vec type + let list_fields = variants[0].tuple_fields().unwrap(); + assert_eq!(list_fields.len(), 1); + match &list_fields[0] { + crate::models::TypeStructure::Array(inner) => { + assert_eq!( + **inner, + crate::models::TypeStructure::Primitive("string".to_string()) + ); + } + _ => panic!("Expected Array type"), + } + + // Check Map variant has HashMap field + let map_fields = variants[1].struct_fields().unwrap(); + assert_eq!(map_fields.len(), 1); + assert_eq!(map_fields[0].name, "items"); + } } mod type_detection { @@ -846,9 +1115,10 @@ mod tests { #[test] fn test_parse_full_enum_with_all_features() { let parser = parser(); + let mut resolver = type_resolver(); let item: ItemEnum = parse_quote! { #[derive(Serialize, Deserialize)] - #[serde(rename_all = "snake_case")] + #[serde(rename_all = "snake_case", tag = "type")] pub enum Message { #[serde(rename = "simple")] Simple, @@ -857,18 +1127,35 @@ mod tests { } }; let path = Path::new("models.rs"); - let result = parser.parse_enum(&item, path).unwrap(); + let result = parser.parse_enum(&item, path, &mut resolver).unwrap(); assert_eq!(result.name, "Message"); assert_eq!(result.fields.len(), 3); assert_eq!(result.serde_rename_all, Some(RenameRule::SnakeCase)); assert!(result.is_enum); + assert_eq!(result.serde_tag, Some("type".to_string())); - // Check variant types + // Check variant types (legacy fields) assert_eq!(result.fields[0].rust_type, "enum_variant"); assert_eq!(result.fields[0].serde_rename, Some("simple".to_string())); assert_eq!(result.fields[1].rust_type, "enum_variant_tuple"); assert_eq!(result.fields[2].rust_type, "enum_variant_struct"); + + // Check enum_variants (new format) + let variants = result.enum_variants.as_ref().unwrap(); + assert_eq!(variants.len(), 3); + assert!(variants[0].is_unit()); + assert!(variants[1].is_tuple()); + assert!(variants[2].is_struct()); + + // Check Text tuple variant fields + let text_fields = variants[1].tuple_fields().unwrap(); + assert_eq!(text_fields.len(), 1); + + // Check User struct variant fields + let user_fields = variants[2].struct_fields().unwrap(); + assert_eq!(user_fields.len(), 1); + assert_eq!(user_fields[0].name, "id"); } } } diff --git a/src/build/generation_cache.rs b/src/build/generation_cache.rs index a6bf6ed..060b77d 100644 --- a/src/build/generation_cache.rs +++ b/src/build/generation_cache.rs @@ -466,6 +466,8 @@ mod tests { file_path: "test.rs".to_string(), is_enum: false, serde_rename_all: None, + serde_tag: None, + enum_variants: None, }; let struct_b = StructInfo { @@ -482,6 +484,8 @@ mod tests { file_path: "test.rs".to_string(), is_enum: false, serde_rename_all: None, + serde_tag: None, + enum_variants: None, }; // Insert in order A, B diff --git a/src/generators/base/template_context.rs b/src/generators/base/template_context.rs index f2b858d..9e3454b 100644 --- a/src/generators/base/template_context.rs +++ b/src/generators/base/template_context.rs @@ -1,5 +1,7 @@ use crate::generators::base::type_visitor::TypeVisitor; -use crate::models::{ChannelInfo, CommandInfo, EventInfo, FieldInfo, ParameterInfo}; +use crate::models::{ + ChannelInfo, CommandInfo, EnumVariantInfo, EnumVariantKind, EventInfo, FieldInfo, ParameterInfo, +}; use crate::{GenerateConfig, TypeStructure}; use serde::{Deserialize, Serialize}; use serde_rename_rule::RenameRule; @@ -327,6 +329,89 @@ impl FieldContext { } } +/// Template context wrapper for enum variants with computed TypeScript-specific fields +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EnumVariantContext { + /// Original variant name + pub name: String, + /// Serialized name (after applying serde rename rules) + pub serialized_name: String, + /// Variant kind: "unit", "tuple", or "struct" + pub kind: String, + /// TypeScript types for tuple variant fields (e.g., ["number", "number"] for Move(i32, i32)) + pub tuple_types: Vec, + /// Zod schemas for tuple variant fields (e.g., ["z.number()", "z.number()"]) + pub tuple_zod_types: Vec, + /// Field contexts for struct variant fields + pub struct_fields: Vec, + #[serde(skip)] + config: GenerateConfig, +} + +impl NamingContext for EnumVariantContext { + fn config(&self) -> &GenerateConfig { + &self.config + } +} + +impl EnumVariantContext { + /// Create a new EnumVariantContext with the given config + pub fn new(config: &GenerateConfig) -> Self { + Self { + name: String::new(), + serialized_name: String::new(), + kind: String::new(), + tuple_types: Vec::new(), + tuple_zod_types: Vec::new(), + struct_fields: Vec::new(), + config: config.clone(), + } + } + + /// Populate this context from an EnumVariantInfo + pub fn from_variant_info( + mut self, + variant: &EnumVariantInfo, + enum_rename_all: &Option, + visitor: &V, + ) -> Self { + // Compute serialized name from serde attributes + let serialized_name = + self.compute_field_name(&variant.name, &variant.serde_rename, enum_rename_all); + + self.name = variant.name.clone(); + self.serialized_name = serialized_name; + + match &variant.kind { + EnumVariantKind::Unit => { + self.kind = "unit".to_string(); + } + EnumVariantKind::Tuple(types) => { + self.kind = "tuple".to_string(); + self.tuple_types = types + .iter() + .map(|t| visitor.visit_type_for_interface(t)) + .collect(); + self.tuple_zod_types = types.iter().map(|t| visitor.visit_type(t)).collect(); + } + EnumVariantKind::Struct(fields) => { + self.kind = "struct".to_string(); + self.struct_fields = fields + .iter() + .map(|field| { + // Struct variant fields don't inherit enum's rename_all + // They use their own serde attributes + FieldContext::new(&self.config).from_field_info(field, &None, visitor) + }) + .collect(); + } + } + + self + } +} + /// Template context wrapper for StructInfo with computed TypeScript-specific fields #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -334,6 +419,12 @@ pub struct StructContext { pub name: String, pub fields: Vec, pub is_enum: bool, + /// Whether this is a simple enum (all unit variants) - can use string literal union + pub is_simple_enum: bool, + /// The discriminator tag name for complex enums (default: "type") + pub discriminator_tag: String, + /// Enum variants with full type information (only for enums) + pub enum_variants: Vec, #[serde(skip)] config: GenerateConfig, } @@ -351,6 +442,9 @@ impl StructContext { name: String::new(), fields: Vec::new(), is_enum: false, + is_simple_enum: false, + discriminator_tag: "type".to_string(), + enum_variants: Vec::new(), config: config.clone(), } } @@ -374,9 +468,30 @@ impl StructContext { }) .collect(); + // Build enum variant contexts if this is an enum with enum_variants + let enum_variants: Vec = struct_info + .enum_variants + .as_ref() + .map(|variants| { + variants + .iter() + .map(|v| { + EnumVariantContext::new(&self.config).from_variant_info( + v, + &struct_info.serde_rename_all, + visitor, + ) + }) + .collect() + }) + .unwrap_or_default(); + self.name = name.to_string(); self.fields = field_contexts; self.is_enum = struct_info.is_enum; + self.is_simple_enum = struct_info.is_simple_enum(); + self.discriminator_tag = struct_info.discriminator_tag().to_string(); + self.enum_variants = enum_variants; self } @@ -800,6 +915,22 @@ mod tests { assert_eq!(ctx.name, ""); assert_eq!(ctx.fields.len(), 0); assert!(!ctx.is_enum); + assert!(!ctx.is_simple_enum); + assert_eq!(ctx.discriminator_tag, "type"); + assert_eq!(ctx.enum_variants.len(), 0); + } + + #[test] + fn test_enum_variant_context_builder_pattern() { + let config = mock_config(); + let ctx = EnumVariantContext::new(&config); + + assert_eq!(ctx.name, ""); + assert_eq!(ctx.serialized_name, ""); + assert_eq!(ctx.kind, ""); + assert_eq!(ctx.tuple_types.len(), 0); + assert_eq!(ctx.tuple_zod_types.len(), 0); + assert_eq!(ctx.struct_fields.len(), 0); } #[test] diff --git a/src/generators/mod.rs b/src/generators/mod.rs index ac38563..995a158 100644 --- a/src/generators/mod.rs +++ b/src/generators/mod.rs @@ -435,6 +435,8 @@ mod tests { file_path: "test.rs".to_string(), is_enum: false, serde_rename_all: None, + serde_tag: None, + enum_variants: None, } } @@ -561,6 +563,8 @@ mod tests { file_path: "test.rs".to_string(), is_enum: false, serde_rename_all: None, + serde_tag: None, + enum_variants: None, } } diff --git a/src/generators/ts/templates/partials/enum.tera b/src/generators/ts/templates/partials/enum.tera index 2c25830..b945166 100644 --- a/src/generators/ts/templates/partials/enum.tera +++ b/src/generators/ts/templates/partials/enum.tera @@ -1,3 +1,18 @@ +{% if struct.isSimpleEnum -%} +{# Simple enum: string literal union #} export type {{ name }} = {% for field in fields -%} "{{ field.serializedName }}"{% if not loop.last %} | {% endif %} {%- endfor %}; +{% else -%} +{# Complex enum: discriminated union #} +export type {{ name }} = +{% for variant in struct.enumVariants -%} + | {% if variant.kind == "unit" -%} +{ {{ struct.discriminatorTag }}: "{{ variant.serializedName }}" } +{%- elif variant.kind == "tuple" -%} +{ {{ struct.discriminatorTag }}: "{{ variant.serializedName }}"; data: {% if variant.tupleTypes | length == 1 %}{{ variant.tupleTypes | first }}{% else %}[{% for t in variant.tupleTypes %}{{ t }}{% if not loop.last %}, {% endif %}{% endfor %}]{% endif %} } +{%- elif variant.kind == "struct" -%} +{ {{ struct.discriminatorTag }}: "{{ variant.serializedName }}"{% for field in variant.structFields %}; {{ field.serializedName }}: {{ field.typescriptType }}{% endfor %} } +{%- endif %} +{% endfor -%}; +{% endif -%} diff --git a/src/generators/zod/generator.rs b/src/generators/zod/generator.rs index 9d20bd3..c14d141 100644 --- a/src/generators/zod/generator.rs +++ b/src/generators/zod/generator.rs @@ -1,6 +1,6 @@ use crate::analysis::CommandAnalyzer; use crate::generators::base::file_writer::FileWriter; -use crate::generators::base::template_context::FieldContext; +use crate::generators::base::template_context::{FieldContext, StructContext}; use crate::generators::base::templates::TemplateRegistry; use crate::generators::base::BaseBindingsGenerator; use crate::generators::zod::schema_builder::ZodSchemaBuilder; @@ -40,7 +40,7 @@ impl ZodBindingsGenerator { } } - /// Generate Zod schema for an enum + /// Generate Zod schema for an enum using templates fn generate_enum_schema( &self, name: &str, @@ -48,22 +48,34 @@ impl ZodBindingsGenerator { config: &GenerateConfig, ) -> String { let visitor = ZodVisitor::with_config(config); + let schema_builder = ZodSchemaBuilder::new(config); - // Convert fields to context to get serialized names - let field_contexts: Vec = - self.collector - .create_field_contexts(struct_info, &visitor, config); + // Create StructContext with all enum information + let mut struct_context = + StructContext::new(config).from_struct_info(name, struct_info, &visitor); + + // For complex enums, enrich struct variant fields with proper Zod schemas + if !struct_info.is_simple_enum() { + for variant in &mut struct_context.enum_variants { + for field in &mut variant.struct_fields { + let zod_schema = schema_builder + .build_schema(&field.type_structure, &field.validator_attributes); + field.typescript_type = zod_schema; + } + } + } - let variants: Vec = field_contexts - .iter() - .map(|field| format!("\"{}\"", field.serialized_name)) - .collect(); + // Prepare template context + let mut context = Context::new(); + context.insert("name", name); + context.insert("struct", &struct_context); + context.insert("fields", &struct_context.fields); - let enum_values = variants.join(", "); - format!( - "export const {}Schema = z.enum([{}]);\n\n", - name, enum_values - ) + self.render("zod/partials/enum_schema.ts.tera", &context) + .unwrap_or_else(|e| { + eprintln!("Template rendering failed for enum {}: {}", name, e); + format!("// Error generating schema for {}: {}\n", name, e) + }) } /// Generate Zod schema for an object/struct using templates @@ -432,9 +444,47 @@ mod tests { } fn create_test_struct(is_enum: bool) -> StructInfo { - StructInfo { - name: "TestStruct".to_string(), - fields: vec![FieldInfo { + use crate::models::{EnumVariantInfo, EnumVariantKind}; + + let (fields, enum_variants) = if is_enum { + // For enums, create proper enum_variants + let variants = vec![ + EnumVariantInfo { + name: "Variant1".to_string(), + kind: EnumVariantKind::Unit, + serde_rename: None, + }, + EnumVariantInfo { + name: "Variant2".to_string(), + kind: EnumVariantKind::Unit, + serde_rename: None, + }, + ]; + // Legacy fields for backward compatibility + let fields = vec![ + FieldInfo { + name: "Variant1".to_string(), + rust_type: "enum_variant".to_string(), + is_optional: false, + is_public: true, + type_structure: TypeStructure::Primitive("string".to_string()), + serde_rename: None, + validator_attributes: None, + }, + FieldInfo { + name: "Variant2".to_string(), + rust_type: "enum_variant".to_string(), + is_optional: false, + is_public: true, + type_structure: TypeStructure::Primitive("string".to_string()), + serde_rename: None, + validator_attributes: None, + }, + ]; + (fields, Some(variants)) + } else { + // For structs, create normal fields + let fields = vec![FieldInfo { name: "test_field".to_string(), rust_type: "String".to_string(), is_optional: false, @@ -442,10 +492,18 @@ mod tests { type_structure: TypeStructure::Primitive("string".to_string()), serde_rename: None, validator_attributes: None, - }], + }]; + (fields, None) + }; + + StructInfo { + name: "TestStruct".to_string(), + fields, file_path: "test.rs".to_string(), is_enum, serde_rename_all: None, + serde_tag: None, + enum_variants, } } diff --git a/src/generators/zod/templates/partials/enum_schema.ts.tera b/src/generators/zod/templates/partials/enum_schema.ts.tera index c2f4da9..39b332d 100644 --- a/src/generators/zod/templates/partials/enum_schema.ts.tera +++ b/src/generators/zod/templates/partials/enum_schema.ts.tera @@ -1,5 +1,23 @@ -export const {{ name }}Schema = {% for field in fields -%} -{{ field.typeStructure | to_zod_schema }}{% if not loop.last %}.or({% endif %} -{%- endfor %}{% for field in fields %}{% if not loop.first %}){% endif %}{% endfor %}; +{% if struct.isSimpleEnum -%} +{# Simple enum: z.enum #} +export const {{ name }}Schema = z.enum([{% for field in fields -%} +"{{ field.serializedName }}"{% if not loop.last %}, {% endif %} +{%- endfor %}]); export type {{ name }} = z.infer; +{% else -%} +{# Complex enum: z.discriminatedUnion #} +export const {{ name }}Schema = z.discriminatedUnion("{{ struct.discriminatorTag }}", [ +{% for variant in struct.enumVariants -%} + {% if variant.kind == "unit" -%} + z.object({ {{ struct.discriminatorTag }}: z.literal("{{ variant.serializedName }}") }), + {%- elif variant.kind == "tuple" -%} + z.object({ {{ struct.discriminatorTag }}: z.literal("{{ variant.serializedName }}"), data: {% if variant.tupleZodTypes | length == 1 %}{{ variant.tupleZodTypes | first }}{% else %}z.tuple([{% for t in variant.tupleZodTypes %}{{ t }}{% if not loop.last %}, {% endif %}{% endfor %}]){% endif %} }), + {%- elif variant.kind == "struct" -%} + z.object({ {{ struct.discriminatorTag }}: z.literal("{{ variant.serializedName }}"){% for field in variant.structFields %}, {{ field.serializedName }}: {{ field.typescriptType }}{% endfor %} }), + {%- endif %} +{% endfor -%} +]); + +export type {{ name }} = z.infer; +{% endif -%} diff --git a/src/models.rs b/src/models.rs index 92e725a..e2f9257 100644 --- a/src/models.rs +++ b/src/models.rs @@ -3,7 +3,7 @@ use serde_rename_rule::RenameRule; /// Represents the structure of a type for code generation /// This allows generators to work with parsed type information instead of string parsing -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "camelCase")] pub enum TypeStructure { /// Primitive types: "string", "number", "boolean", "void" @@ -41,6 +41,62 @@ impl Default for TypeStructure { } } +/// Represents the kind of an enum variant for discriminated union generation +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub enum EnumVariantKind { + /// Unit variant: `Quit` + Unit, + /// Tuple variant with unnamed fields: `Move(i32, i32)` or `Write(String)` + Tuple(Vec), + /// Struct variant with named fields: `ChangeColor { r: u8, g: u8, b: u8 }` + Struct(Vec), +} + +/// Information about an enum variant for discriminated union generation +#[derive(Debug, Clone)] +pub struct EnumVariantInfo { + /// The variant name (e.g., "Quit", "Move", "ChangeColor") + pub name: String, + /// The kind of variant and its associated data + pub kind: EnumVariantKind, + /// Serde rename attribute: #[serde(rename = "...")] + pub serde_rename: Option, +} + +impl EnumVariantInfo { + /// Returns true if this is a unit variant (no associated data) + pub fn is_unit(&self) -> bool { + matches!(self.kind, EnumVariantKind::Unit) + } + + /// Returns true if this is a tuple variant (unnamed fields) + pub fn is_tuple(&self) -> bool { + matches!(self.kind, EnumVariantKind::Tuple(_)) + } + + /// Returns true if this is a struct variant (named fields) + pub fn is_struct(&self) -> bool { + matches!(self.kind, EnumVariantKind::Struct(_)) + } + + /// Returns the tuple fields if this is a tuple variant + pub fn tuple_fields(&self) -> Option<&Vec> { + match &self.kind { + EnumVariantKind::Tuple(fields) => Some(fields), + _ => None, + } + } + + /// Returns the struct fields if this is a struct variant + pub fn struct_fields(&self) -> Option<&Vec> { + match &self.kind { + EnumVariantKind::Struct(fields) => Some(fields), + _ => None, + } + } +} + pub struct CommandInfo { pub name: String, pub file_path: String, @@ -106,9 +162,44 @@ pub struct StructInfo { pub is_enum: bool, /// Serde rename_all attribute: #[serde(rename_all = "...")] pub serde_rename_all: Option, + /// Serde tag attribute for enums: #[serde(tag = "...")] + /// Used for internally-tagged enum representation + pub serde_tag: Option, + /// Enum variants with full type information (only populated for enums) + /// When populated, provides richer variant data than the `fields` vector + pub enum_variants: Option>, } -#[derive(Clone, Debug)] +impl StructInfo { + /// Returns true if this is a simple enum (all unit variants) + /// Simple enums can be represented as TypeScript string literal unions + pub fn is_simple_enum(&self) -> bool { + if !self.is_enum { + return false; + } + + match &self.enum_variants { + Some(variants) => variants.iter().all(|v| v.is_unit()), + // Fallback to checking fields for backward compatibility + None => self.fields.iter().all(|f| f.rust_type == "enum_variant"), + } + } + + /// Returns true if this is a complex enum (has tuple or struct variants) + /// Complex enums need discriminated union representation in TypeScript + pub fn is_complex_enum(&self) -> bool { + self.is_enum && !self.is_simple_enum() + } + + /// Returns the discriminator tag name for this enum + /// Defaults to "type" if not specified via #[serde(tag = "...")] + pub fn discriminator_tag(&self) -> &str { + self.serde_tag.as_deref().unwrap_or("type") + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] pub struct FieldInfo { pub name: String, pub rust_type: String, @@ -121,7 +212,7 @@ pub struct FieldInfo { pub type_structure: TypeStructure, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "camelCase")] pub struct ValidatorAttributes { pub length: Option, @@ -131,7 +222,7 @@ pub struct ValidatorAttributes { pub custom_message: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "camelCase")] pub struct LengthConstraint { pub min: Option, @@ -139,7 +230,7 @@ pub struct LengthConstraint { pub message: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "camelCase")] pub struct RangeConstraint { pub min: Option, @@ -731,6 +822,8 @@ mod tests { file_path: "src/models.rs".to_string(), is_enum: false, serde_rename_all: None, + serde_tag: None, + enum_variants: None, }; assert_eq!(struct_info.name, "User"); @@ -746,6 +839,8 @@ mod tests { file_path: "src/types.rs".to_string(), is_enum: true, serde_rename_all: Some(RenameRule::CamelCase), + serde_tag: None, + enum_variants: None, }; assert!(struct_info.is_enum); @@ -760,12 +855,327 @@ mod tests { file_path: "src/product.rs".to_string(), is_enum: false, serde_rename_all: None, + serde_tag: None, + enum_variants: None, }; let cloned = original.clone(); assert_eq!(cloned.name, "Product"); assert!(!cloned.is_enum); } + + #[test] + fn test_simple_enum_detection() { + // Simple enum with unit variants only + let simple_enum = StructInfo { + name: "Status".to_string(), + fields: vec![ + FieldInfo { + name: "Active".to_string(), + rust_type: "enum_variant".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: None, + type_structure: TypeStructure::Custom("enum_variant".to_string()), + }, + FieldInfo { + name: "Inactive".to_string(), + rust_type: "enum_variant".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: None, + type_structure: TypeStructure::Custom("enum_variant".to_string()), + }, + ], + file_path: "src/types.rs".to_string(), + is_enum: true, + serde_rename_all: None, + serde_tag: None, + enum_variants: None, + }; + + assert!(simple_enum.is_simple_enum()); + assert!(!simple_enum.is_complex_enum()); + } + + #[test] + fn test_complex_enum_detection_via_fields() { + // Complex enum detected via rust_type field (backward compatibility) + let complex_enum = StructInfo { + name: "Message".to_string(), + fields: vec![ + FieldInfo { + name: "Quit".to_string(), + rust_type: "enum_variant".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: None, + type_structure: TypeStructure::Custom("enum_variant".to_string()), + }, + FieldInfo { + name: "Move".to_string(), + rust_type: "enum_variant_tuple".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: None, + type_structure: TypeStructure::Custom("enum_variant".to_string()), + }, + ], + file_path: "src/types.rs".to_string(), + is_enum: true, + serde_rename_all: None, + serde_tag: None, + enum_variants: None, + }; + + assert!(!complex_enum.is_simple_enum()); + assert!(complex_enum.is_complex_enum()); + } + + #[test] + fn test_complex_enum_detection_via_enum_variants() { + // Complex enum with EnumVariantInfo populated + let complex_enum = StructInfo { + name: "Message".to_string(), + fields: vec![], + file_path: "src/types.rs".to_string(), + is_enum: true, + serde_rename_all: None, + serde_tag: Some("type".to_string()), + enum_variants: Some(vec![ + EnumVariantInfo { + name: "Quit".to_string(), + kind: EnumVariantKind::Unit, + serde_rename: None, + }, + EnumVariantInfo { + name: "Move".to_string(), + kind: EnumVariantKind::Tuple(vec![ + TypeStructure::Primitive("number".to_string()), + TypeStructure::Primitive("number".to_string()), + ]), + serde_rename: None, + }, + ]), + }; + + assert!(!complex_enum.is_simple_enum()); + assert!(complex_enum.is_complex_enum()); + } + + #[test] + fn test_discriminator_tag_default() { + let enum_info = StructInfo { + name: "Status".to_string(), + fields: vec![], + file_path: "src/types.rs".to_string(), + is_enum: true, + serde_rename_all: None, + serde_tag: None, + enum_variants: None, + }; + + assert_eq!(enum_info.discriminator_tag(), "type"); + } + + #[test] + fn test_discriminator_tag_custom() { + let enum_info = StructInfo { + name: "Status".to_string(), + fields: vec![], + file_path: "src/types.rs".to_string(), + is_enum: true, + serde_rename_all: None, + serde_tag: Some("kind".to_string()), + enum_variants: None, + }; + + assert_eq!(enum_info.discriminator_tag(), "kind"); + } + } + + // EnumVariantKind tests + mod enum_variant_kind { + use super::*; + + #[test] + fn test_unit_variant() { + let kind = EnumVariantKind::Unit; + assert_eq!(kind, EnumVariantKind::Unit); + } + + #[test] + fn test_tuple_variant_single_field() { + let kind = EnumVariantKind::Tuple(vec![TypeStructure::Primitive("string".to_string())]); + + match kind { + EnumVariantKind::Tuple(fields) => { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0], TypeStructure::Primitive("string".to_string())); + } + _ => panic!("Should be Tuple variant"), + } + } + + #[test] + fn test_tuple_variant_multiple_fields() { + let kind = EnumVariantKind::Tuple(vec![ + TypeStructure::Primitive("number".to_string()), + TypeStructure::Primitive("number".to_string()), + ]); + + match kind { + EnumVariantKind::Tuple(fields) => { + assert_eq!(fields.len(), 2); + } + _ => panic!("Should be Tuple variant"), + } + } + + #[test] + fn test_struct_variant() { + let fields = vec![ + FieldInfo { + name: "r".to_string(), + rust_type: "u8".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: None, + type_structure: TypeStructure::Primitive("number".to_string()), + }, + FieldInfo { + name: "g".to_string(), + rust_type: "u8".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: None, + type_structure: TypeStructure::Primitive("number".to_string()), + }, + ]; + let kind = EnumVariantKind::Struct(fields); + + match kind { + EnumVariantKind::Struct(f) => { + assert_eq!(f.len(), 2); + assert_eq!(f[0].name, "r"); + assert_eq!(f[1].name, "g"); + } + _ => panic!("Should be Struct variant"), + } + } + + #[test] + fn test_serialize_deserialize() { + let unit = EnumVariantKind::Unit; + let json = serde_json::to_string(&unit).unwrap(); + let deserialized: EnumVariantKind = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, EnumVariantKind::Unit); + + let tuple = + EnumVariantKind::Tuple(vec![TypeStructure::Primitive("string".to_string())]); + let json = serde_json::to_string(&tuple).unwrap(); + let deserialized: EnumVariantKind = serde_json::from_str(&json).unwrap(); + match deserialized { + EnumVariantKind::Tuple(fields) => assert_eq!(fields.len(), 1), + _ => panic!("Should deserialize to Tuple"), + } + } + } + + // EnumVariantInfo tests + mod enum_variant_info { + use super::*; + + #[test] + fn test_unit_variant_helpers() { + let variant = EnumVariantInfo { + name: "Quit".to_string(), + kind: EnumVariantKind::Unit, + serde_rename: None, + }; + + assert!(variant.is_unit()); + assert!(!variant.is_tuple()); + assert!(!variant.is_struct()); + assert!(variant.tuple_fields().is_none()); + assert!(variant.struct_fields().is_none()); + } + + #[test] + fn test_tuple_variant_helpers() { + let variant = EnumVariantInfo { + name: "Move".to_string(), + kind: EnumVariantKind::Tuple(vec![ + TypeStructure::Primitive("number".to_string()), + TypeStructure::Primitive("number".to_string()), + ]), + serde_rename: None, + }; + + assert!(!variant.is_unit()); + assert!(variant.is_tuple()); + assert!(!variant.is_struct()); + + let fields = variant.tuple_fields().unwrap(); + assert_eq!(fields.len(), 2); + assert!(variant.struct_fields().is_none()); + } + + #[test] + fn test_struct_variant_helpers() { + let variant = EnumVariantInfo { + name: "ChangeColor".to_string(), + kind: EnumVariantKind::Struct(vec![FieldInfo { + name: "r".to_string(), + rust_type: "u8".to_string(), + is_optional: false, + is_public: true, + validator_attributes: None, + serde_rename: None, + type_structure: TypeStructure::Primitive("number".to_string()), + }]), + serde_rename: None, + }; + + assert!(!variant.is_unit()); + assert!(!variant.is_tuple()); + assert!(variant.is_struct()); + + assert!(variant.tuple_fields().is_none()); + let fields = variant.struct_fields().unwrap(); + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name, "r"); + } + + #[test] + fn test_variant_with_serde_rename() { + let variant = EnumVariantInfo { + name: "Quit".to_string(), + kind: EnumVariantKind::Unit, + serde_rename: Some("quit".to_string()), + }; + + assert_eq!(variant.serde_rename, Some("quit".to_string())); + } + + #[test] + fn test_clone() { + let original = EnumVariantInfo { + name: "Write".to_string(), + kind: EnumVariantKind::Tuple(vec![TypeStructure::Primitive("string".to_string())]), + serde_rename: None, + }; + + let cloned = original.clone(); + assert_eq!(cloned.name, "Write"); + assert!(cloned.is_tuple()); + } } // FieldInfo tests diff --git a/tests/integration_e2e.rs b/tests/integration_e2e.rs index 755fa0d..6d563ba 100644 --- a/tests/integration_e2e.rs +++ b/tests/integration_e2e.rs @@ -511,3 +511,212 @@ fn test_event_payload_discovery_from_helper_function() { events_file ); } + +/// Test complex enum (discriminated union) TypeScript generation +#[test] +fn test_complex_enum_typescript_generation() { + let project = TestProject::new(); + + project.write_file( + "main.rs", + r#" + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Clone, Serialize, Deserialize)] + #[serde(tag = "type")] + pub enum Message { + Quit, + Move(i32, i32), + Write(String), + ChangeColor { r: u8, g: u8, b: u8 }, + } + + #[derive(Debug, Clone, Serialize, Deserialize)] + pub enum Status { + Active, + Inactive, + Pending, + } + + #[tauri::command] + pub fn send_message(msg: Message) -> Result { + Ok(Status::Active) + } + "#, + ); + + let (analyzer, commands) = project.analyze(); + let generator = TestGenerator::new(); + generator.generate( + &commands, + analyzer.get_discovered_structs(), + &analyzer, + Some("none"), + None, + ); + + let types = generator.read_file("types.ts"); + + // Simple enum should be string literal union + assert!( + types.contains(r#"export type Status = "Active" | "Inactive" | "Pending";"#), + "Simple enum should use string literal union. Got:\n{}", + types + ); + + // Complex enum should be discriminated union + assert!( + types.contains("export type Message ="), + "Complex enum should have type declaration. Got:\n{}", + types + ); + + // Check for discriminated union structure + assert!( + types.contains(r#"type: "Quit""#), + "Should have Quit variant. Got:\n{}", + types + ); + assert!( + types.contains(r#"type: "Move""#), + "Should have Move variant. Got:\n{}", + types + ); + assert!( + types.contains(r#"type: "Write""#), + "Should have Write variant. Got:\n{}", + types + ); + assert!( + types.contains(r#"type: "ChangeColor""#), + "Should have ChangeColor variant. Got:\n{}", + types + ); + + // Check tuple variant has data field + assert!( + types.contains("data:"), + "Tuple variants should have data field. Got:\n{}", + types + ); + + // Check struct variant has named fields + assert!( + types.contains("r: number"), + "Struct variant should have r field. Got:\n{}", + types + ); +} + +/// Test complex enum Zod schema generation (discriminated union) +#[test] +fn test_complex_enum_zod_generation() { + let project = TestProject::new(); + + project.write_file( + "main.rs", + r#" + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Clone, Serialize, Deserialize)] + #[serde(tag = "kind")] + pub enum Action { + Start, + Move { x: i32, y: i32 }, + Send(String), + } + + #[derive(Debug, Clone, Serialize, Deserialize)] + pub enum Status { + Active, + Inactive, + } + + #[tauri::command] + pub fn perform_action(action: Action) -> Result { + Ok(Status::Active) + } + "#, + ); + + let (analyzer, commands) = project.analyze(); + let generator = TestGenerator::new(); + generator.generate( + &commands, + analyzer.get_discovered_structs(), + &analyzer, + Some("zod"), + None, + ); + + let types = generator.read_file("types.ts"); + + // Simple enum should use z.enum + assert!( + types.contains("StatusSchema = z.enum"), + "Simple enum should use z.enum. Got:\n{}", + types + ); + assert!( + types.contains(r#"["Active", "Inactive"]"#), + "Simple enum should list variants. Got:\n{}", + types + ); + + // Complex enum should use z.discriminatedUnion + assert!( + types.contains("ActionSchema = z.discriminatedUnion"), + "Complex enum should use z.discriminatedUnion. Got:\n{}", + types + ); + + // Check discriminator is "kind" (from serde tag) + assert!( + types.contains(r#"z.discriminatedUnion("kind""#), + "Should use 'kind' as discriminator. Got:\n{}", + types + ); + + // Check unit variant + assert!( + types.contains(r#"z.literal("Start")"#), + "Should have Start variant. Got:\n{}", + types + ); + + // Check struct variant with named fields + assert!( + types.contains(r#"z.literal("Move")"#), + "Should have Move variant. Got:\n{}", + types + ); + assert!( + types.contains("x:") && types.contains("y:"), + "Move variant should have x and y fields. Got:\n{}", + types + ); + + // Check tuple variant with data field + assert!( + types.contains(r#"z.literal("Send")"#), + "Should have Send variant. Got:\n{}", + types + ); + assert!( + types.contains("data:"), + "Tuple variant should have data field. Got:\n{}", + types + ); + + // Verify type inference + assert!( + types.contains("export type Action = z.infer"), + "Should export inferred Action type. Got:\n{}", + types + ); + assert!( + types.contains("export type Status = z.infer"), + "Should export inferred Status type. Got:\n{}", + types + ); +}