Skip to content

Add safe attribute helper to SafeDebug #2613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/typespec/typespec_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub fn derive_model(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
/// assert_eq!(format!("{model:?}"), "MyModel { .. }");
/// }
/// ```
#[proc_macro_derive(SafeDebug)]
#[proc_macro_derive(SafeDebug, attributes(safe))]
pub fn derive_safe_debug(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
run_derive_macro(input, safe_debug::derive_safe_debug_impl)
}
283 changes: 247 additions & 36 deletions sdk/typespec/typespec_macros/src/safe_debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::Result;
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
punctuated::Punctuated, spanned::Spanned, token::Comma, Data, DataEnum, DataStruct,
punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, Data, DataEnum, DataStruct,
DeriveInput, Error, Field, Fields, FieldsNamed, FieldsUnnamed, Ident, Path,
};

Expand All @@ -26,14 +26,21 @@ fn generate_body(ast: DeriveInput) -> Result<TokenStream> {
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
let name = &ast.ident;

let type_attrs = Attrs::from_attrs(&ast.attrs)?;
let body = match &ast.data {
Data::Enum(DataEnum { variants, .. }) => {
let variants = variants.iter().map(|v| {
let variant_name = &v.ident;
let path = to_path(&[name, variant_name]);
let variants = variants
.iter()
.map(|v| -> Result<TokenStream> {
let variant_name = &v.ident;
let path = to_path(&[name, variant_name]);

let mut enum_attrs = Attrs::from_attrs(&v.attrs)?;
enum_attrs.and(&type_attrs);

generate_fields(&path, &v.fields)
});
generate_fields(&path, &enum_attrs, &v.fields)
})
.collect::<Result<Vec<_>>>()?;

quote! {
match self {
Expand All @@ -43,7 +50,7 @@ fn generate_body(ast: DeriveInput) -> Result<TokenStream> {
}
Data::Struct(DataStruct { fields, .. }) => {
let path = to_path(&[name]);
let fields = generate_fields(&path, fields);
let fields = generate_fields(&path, &type_attrs, fields)?;

quote! {
match self {
Expand All @@ -64,22 +71,28 @@ fn generate_body(ast: DeriveInput) -> Result<TokenStream> {
})
}

fn generate_fields(path: &Path, fields: &Fields) -> TokenStream {
fn generate_fields(path: &Path, type_attrs: &Attrs, fields: &Fields) -> Result<TokenStream> {
let name = &path.segments.last().expect("expected identifier").ident;
let name_str = name.to_string();

match fields {
Fields::Named(FieldsNamed { ref named, .. }) => {
let names: Vec<&Ident> = if cfg!(feature = "debug") {
named
.iter()
.map(|f| f.ident.as_ref().expect("expected named field"))
.collect()
} else {
// Should we ever add a `#[safe(bool)]` helper attribute to denote which fields we can safely include,
// filter the fields to match and emit based on the inherited value or field attribute value.
Vec::new()
};
let names: Vec<&Ident> = named
.iter()
.filter_map(|f| -> Option<Result<&Ident>> {
if cfg!(feature = "debug") {
return Some(Ok(f.ident.as_ref().expect("expected named field")));
}

match Attrs::from_attrs(&f.attrs) {
Err(err) => Some(Err(err)),
Ok(attrs) if type_attrs.is_safe_and(&attrs) => {
Some(Ok(f.ident.as_ref().expect("expected named field")))
}
Ok(_) => None,
}
})
.collect::<Result<Vec<_>>>()?;
let fields: Vec<TokenStream> = names
.iter()
.map(|field_name| {
Expand All @@ -90,37 +103,45 @@ fn generate_fields(path: &Path, fields: &Fields) -> TokenStream {

// Use an "and the rest" matcher as needed, along with the appropriate `DebugStruct` finisher.
let (matcher, finisher) = finish(&fields, named, false);
quote! {
Ok(quote! {
#path { #(#names),* #matcher } => f
.debug_struct(#name_str)
#(#fields)*
#finisher
}
})
}
Fields::Unit => quote! {#path => f.write_str(#name_str)},
Fields::Unit => Ok(quote! {#path => f.write_str(#name_str)}),
Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }) => {
let indices: Vec<TokenStream> = if cfg!(feature = "debug") {
unnamed
.iter()
.enumerate()
.map(|(i, _)| {
Ident::new(&format!("f{i}"), Span::call_site()).into_token_stream()
})
.collect()
} else {
// Should we ever add a `#[safe(bool)]` helper attribute to denote which fields we can safely include,
// filter the fields to match and emit based on the inherited value or field attribute value.
Vec::new()
};
let indices: Vec<TokenStream> = unnamed
.iter()
.enumerate()
.filter_map(|(i, f)| {
if cfg!(feature = "debug") {
return Some(Ok(
Ident::new(&format!("f{i}"), Span::call_site()).into_token_stream()
));
}

match Attrs::from_attrs(&f.attrs) {
Err(err) => Some(Err(err)),
Ok(attrs) if type_attrs.is_safe_and(&attrs) => {
Some(Ok(
Ident::new(&format!("f{i}"), Span::call_site()).into_token_stream()
))
}
Ok(_) => None,
}
})
.collect::<Result<Vec<_>>>()?;

// Use an "and the rest" matcher as needed, along with the appropriate `DebugTuple` finisher.
let (matcher, finisher) = finish(&indices, unnamed, true);
quote! {
Ok(quote! {
#path(#(#indices),* #matcher) => f
.debug_tuple(#name_str)
#(.field(&#indices))*
#finisher
}
})
}
}
}
Expand Down Expand Up @@ -166,3 +187,193 @@ fn to_path(idents: &[&Ident]) -> Path {
segments,
}
}

#[derive(Debug, Default)]
struct Attrs {
safe: Option<bool>,
}

impl Attrs {
fn from_attrs(attributes: &[Attribute]) -> Result<Attrs> {
let mut attrs = Attrs::default();
let mut result = Ok(());
for attribute in attributes.iter().filter(|a| a.path().is_ident("safe")) {
result = match (result, parse_attr(attribute, &mut attrs)) {
(Ok(()), Err(e)) => Err(e),
(Err(mut e1), Err(e2)) => {
e1.combine(e2);
Err(e1)
}
(e, Ok(())) => e,
};
}

result.map(|_| attrs)
}

fn is_safe_and(&self, other: &Attrs) -> bool {
match (self.safe, other.safe) {
(Some(safe), Some(other)) => safe && other,
(None, Some(other)) => other,
(Some(safe), None) => safe,
(None, None) => false,
}
}

fn and(&mut self, other: &Attrs) {
match (self.safe, other.safe) {
(None, Some(other)) => self.safe = Some(other),
(Some(safe), Some(other)) => self.safe = Some(safe && other),
_ => {}
}
}
}

const INVALID_SAFE_ATTRIBUTE_MESSAGE: &str =
"invalid safe attribute, expected attribute in form #[safe(false)] or #[safe(true)]";

fn parse_attr(attribute: &Attribute, attrs: &mut Attrs) -> Result<()> {
let meta_list = attribute
.meta
.require_list()
.map_err(|_| Error::new(attribute.span(), INVALID_SAFE_ATTRIBUTE_MESSAGE))?;
let lit: syn::LitBool = meta_list
.parse_args()
.map_err(|_| Error::new(meta_list.span(), INVALID_SAFE_ATTRIBUTE_MESSAGE))?;
attrs.safe = Some(lit.value);

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn attrs_safe_requires_arg() {
let attr: Attribute = syn::parse_quote! {
#[safe]
};
assert!(
matches!(Attrs::from_attrs(&[attr]), Err(err) if err.to_string() == INVALID_SAFE_ATTRIBUTE_MESSAGE)
);
}

#[test]
fn attrs_safe_requires_bool() {
let attr: Attribute = syn::parse_quote! {
#[safe(false)]
};
assert!(!Attrs::from_attrs(&[attr]).unwrap().safe.unwrap());

let attr: Attribute = syn::parse_quote! {
#[safe(true)]
};
assert!(Attrs::from_attrs(&[attr]).unwrap().safe.unwrap());

let attr: Attribute = syn::parse_quote! {
#[safe(other)]
};
assert!(
matches!(Attrs::from_attrs(&[attr]), Err(err) if err.to_string() == INVALID_SAFE_ATTRIBUTE_MESSAGE)
);
}

#[test]
fn attrs_is_safe_and() {
let mut type_attrs = Attrs::default();
let mut field_attrs = Attrs::default();

// Both None
assert!(!type_attrs.is_safe_and(&field_attrs));

// None, Some(false)
field_attrs.safe = Some(false);
assert!(!type_attrs.is_safe_and(&field_attrs));

// None, Some(true)
field_attrs.safe = Some(true);
assert!(type_attrs.is_safe_and(&field_attrs));

// Some(false), Some(true)
type_attrs.safe = Some(false);
assert!(!type_attrs.is_safe_and(&field_attrs));

// Some(false), Some(false)
field_attrs.safe = Some(false);
assert!(!type_attrs.is_safe_and(&field_attrs));

// Some(true), Some(false)
type_attrs.safe = Some(true);
assert!(!type_attrs.is_safe_and(&field_attrs));

// Some(true), Some(true)
field_attrs.safe = Some(true);
assert!(type_attrs.is_safe_and(&field_attrs));

// Some(true), None
field_attrs.safe = None;
assert!(type_attrs.is_safe_and(&field_attrs));

// Some(false), None
type_attrs.safe = Some(false);
assert!(!type_attrs.is_safe_and(&field_attrs));
}

#[test]
fn attrs_and() {
let mut type_attrs = Attrs::default();
let mut enum_attrs = Attrs::default();

// Both None
type_attrs.and(&enum_attrs);
assert!(type_attrs.safe.is_none());

// None, Some(false)
enum_attrs.safe = Some(false);
type_attrs.and(&enum_attrs);
assert!(!type_attrs.safe.unwrap());

// None, Some(true)
type_attrs.safe = None;
enum_attrs.safe = Some(true);
type_attrs.and(&enum_attrs);
assert!(type_attrs.safe.unwrap());

// Some(false), Some(true)
type_attrs.safe = Some(false);
enum_attrs.safe = Some(true);
type_attrs.and(&enum_attrs);
assert!(!type_attrs.safe.unwrap());

// Some(true), Some(false)
type_attrs.safe = Some(true);
enum_attrs.safe = Some(false);
type_attrs.and(&enum_attrs);
assert!(!type_attrs.safe.unwrap());

// Some(true), Some(true)
type_attrs.safe = Some(true);
enum_attrs.safe = Some(true);
type_attrs.and(&enum_attrs);
assert!(type_attrs.safe.unwrap());

// Some(false), Some(false)
type_attrs.safe = Some(false);
enum_attrs.safe = Some(false);
type_attrs.and(&enum_attrs);
assert!(!type_attrs.safe.unwrap());

// Some(true), None
type_attrs.safe = Some(true);
enum_attrs.safe = None;
type_attrs.and(&enum_attrs);
assert!(type_attrs.safe.unwrap());

// Some(false), None
type_attrs.safe = Some(false);
enum_attrs.safe = None;
type_attrs.and(&enum_attrs);
assert!(!type_attrs.safe.unwrap());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
use typespec_client_core::fmt::SafeDebug;

#[derive(SafeDebug)]
pub struct Tuple(pub i32, pub &'static str);
#[safe(false)]
pub struct Tuple(#[safe(true)] pub i32, pub &'static str);

#[derive(SafeDebug)]
pub struct EmptyTuple();

#[derive(SafeDebug)]
pub struct Struct {
#[safe(true)]
pub a: i32,
pub b: &'static str,
}
Expand All @@ -26,6 +28,11 @@ pub enum Enum {
Unit,
Tuple(i32, &'static str),
EmptyTuple(),
Struct { a: i32, b: &'static str },
#[safe(true)]
Struct {
a: i32,
#[safe(false)]
b: &'static str,
},
EmptyStruct {},
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn safe_debug_empty_tuple() {
#[cfg_attr(not(feature = "debug"), test)]
fn safe_debug_struct() {
let x = Struct { a: 1, b: "foo" };
assert_eq!(format!("{x:?}"), r#"Struct { .. }"#);
assert_eq!(format!("{x:?}"), r#"Struct { a: 1, .. }"#);
}

#[test]
Expand Down Expand Up @@ -71,7 +71,7 @@ fn safe_debug_enum_empty_tuple() {
#[cfg_attr(not(feature = "debug"), test)]
fn safe_debug_enum_struct() {
let x = Enum::Struct { a: 1, b: "foo" };
assert_eq!(format!("{x:?}"), r#"Struct { .. }"#);
assert_eq!(format!("{x:?}"), r#"Struct { a: 1, .. }"#);
}

#[test]
Expand Down
Loading