diff --git a/patchable-macro/src/context.rs b/patchable-macro/src/context.rs index 5edf22f..a3b06b0 100644 --- a/patchable-macro/src/context.rs +++ b/patchable-macro/src/context.rs @@ -7,7 +7,7 @@ //! 1. The companion patch struct (state struct). //! 2. The `Patchable` trait implementation. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use proc_macro_crate::{FoundCrate, crate_name}; use proc_macro2::{Span, TokenStream as TokenStream2}; @@ -20,9 +20,8 @@ use syn::{ }; #[derive(Clone, Copy, PartialEq, Eq)] -enum TypeUsage { - NotPatchable, - Patchable, +struct TypeUsage { + used_in_keep: bool, } pub(crate) struct MacroContext<'a> { @@ -34,6 +33,8 @@ pub(crate) struct MacroContext<'a> { fields: &'a Fields, /// Mapping from preserved type to its usage flag. preserved_types: HashMap<&'a Ident, TypeUsage>, + /// Patchable field types that should implement `Patchable`. + patchable_field_types: Vec<&'a Type>, /// The list of actions to perform for each field when generating the `patch` method and the /// state struct. /// @@ -67,6 +68,7 @@ impl<'a> MacroContext<'a> { } let mut preserved_types: HashMap<&Ident, TypeUsage> = HashMap::new(); + let mut patchable_field_types = Vec::new(); let mut field_actions = vec![]; let stateful_fields: Vec<&Field> = fields @@ -84,14 +86,7 @@ impl<'a> MacroContext<'a> { let field_type = &field.ty; if has_patchable_attr(field) { - let Some(type_name) = get_abstract_simple_type_name(field_type) else { - return Err(syn::Error::new_spanned( - field_type, - "Only a simple generic type can be used", // TODO: remove this limit - )); - }; - // `Patchable` usage overrides `NotPatchable` usage. - preserved_types.insert(type_name, TypeUsage::Patchable); + patchable_field_types.push(field_type); field_actions.push(FieldAction::Patch { member, @@ -99,10 +94,10 @@ impl<'a> MacroContext<'a> { }); } else { for type_name in collect_used_simple_types(field_type) { - // Only mark as `NotPatchable` if not already marked as `Patchable`. preserved_types .entry(type_name) - .or_insert(TypeUsage::NotPatchable); + .and_modify(|usage| usage.used_in_keep = true) + .or_insert(TypeUsage { used_in_keep: true }); } field_actions.push(FieldAction::Keep { member, @@ -119,6 +114,7 @@ impl<'a> MacroContext<'a> { field_actions, state_struct_name: quote::format_ident!("{}State", &input.ident), crate_path: use_site_crate_path(), + patchable_field_types, }) } @@ -135,10 +131,25 @@ impl<'a> MacroContext<'a> { let generic_params = quote! { <#(#state_generic_params),*> }; let mut bounds = Vec::new(); + let patchable_generic_params = self.collect_patchable_generic_params(); for param in self.generics.type_params() { - if let Some(TypeUsage::Patchable) = self.preserved_types.get(¶m.ident) { - let crate_root = &self.crate_path; - bounds.push(quote! { #param: #crate_root :: Patchable }); + if patchable_generic_params.contains(¶m.ident) { + bounds.push(quote! { #param: ::core::clone::Clone }); + } + } + + let crate_root = &self.crate_path; + let mut seen_patchable_bounds = HashSet::new(); + for field_type in &self.patchable_field_types { + let key = field_type.to_token_stream().to_string(); + if seen_patchable_bounds.insert(key) { + bounds.push(quote! { #field_type: #crate_root :: Patchable }); + if !self.is_generic_param_type(field_type) { + bounds.push(quote! { + for<'patchable_de> <#field_type as #crate_root :: Patchable>::Patch: + ::serde::Deserialize<'patchable_de> + }); + } } } let where_clause = if bounds.is_empty() { @@ -194,6 +205,7 @@ impl<'a> MacroContext<'a> { } fn generate_state_fields(&self) -> Vec { + let crate_root = &self.crate_path; self.field_actions .iter() .map(|action| match action { @@ -213,13 +225,23 @@ impl<'a> MacroContext<'a> { member: FieldMember::Named(name), ty, } => { - quote! { #name : #ty :: Patch } + let patch_type = if self.is_generic_param_type(ty) { + quote! { #ty :: Patch } + } else { + quote! { <#ty as #crate_root :: Patchable>::Patch } + }; + quote! { #name : #patch_type } } FieldAction::Patch { member: FieldMember::Unnamed(_), ty, } => { - quote! { #ty :: Patch } + let patch_type = if self.is_generic_param_type(ty) { + quote! { #ty :: Patch } + } else { + quote! { <#ty as #crate_root :: Patchable>::Patch } + }; + quote! { #patch_type } } }) .collect() @@ -250,27 +272,70 @@ impl<'a> MacroContext<'a> { fn collect_state_generics(&self) -> Vec { let mut generics = Vec::new(); + let patchable_generic_params = self.collect_patchable_generic_params(); for param in self.generics.type_params() { - if self.preserved_types.contains_key(¶m.ident) { + if self.preserved_types.contains_key(¶m.ident) + || patchable_generic_params.contains(¶m.ident) + { generics.push(param.ident.clone()); } } generics } + fn collect_patchable_generic_params(&self) -> HashSet<&Ident> { + let mut patchable_generic_params = HashSet::new(); + for field_type in &self.patchable_field_types { + for type_name in collect_used_simple_types(field_type) { + patchable_generic_params.insert(type_name); + } + } + patchable_generic_params + } + + fn is_generic_param_type(&self, ty: &Type) -> bool { + let Type::Path(tp) = ty else { + return false; + }; + if tp.qself.is_some() || tp.path.segments.len() != 1 { + return false; + } + let segment = &tp.path.segments[0]; + if !matches!(segment.arguments, PathArguments::None) { + return false; + } + self.generics + .type_params() + .any(|param| param.ident == segment.ident) + } + fn build_bounds(&self) -> TokenStream2 { let mut bounds = Vec::new(); + let patchable_generic_params = self.collect_patchable_generic_params(); for param in self.generics.type_params() { let t = ¶m.ident; - match self.preserved_types.get(t) { - Some(TypeUsage::Patchable) => { - let crate_root = &self.crate_path; - bounds.push(quote! { #t: #crate_root :: Patchable + ::core::clone::Clone }); - } - Some(TypeUsage::NotPatchable) => { - bounds.push(quote! { #t: ::core::clone::Clone }); + if patchable_generic_params.contains(t) + || self + .preserved_types + .get(t) + .is_some_and(|usage| usage.used_in_keep) + { + bounds.push(quote! { #t: ::core::clone::Clone }); + } + } + + let crate_root = &self.crate_path; + let mut seen_patchable_bounds = HashSet::new(); + for field_type in &self.patchable_field_types { + let key = field_type.to_token_stream().to_string(); + if seen_patchable_bounds.insert(key) { + bounds.push(quote! { #field_type: #crate_root :: Patchable }); + if !self.is_generic_param_type(field_type) { + bounds.push(quote! { + for<'patchable_de> <#field_type as #crate_root :: Patchable>::Patch: + ::serde::Deserialize<'patchable_de> + }); } - None => {} } } @@ -299,9 +364,10 @@ impl<'a> MacroContext<'a> { fn build_associate_type_declaration(&self) -> TokenStream2 { let mut args = Vec::new(); + let patchable_generic_params = self.collect_patchable_generic_params(); for param in self.generics.type_params() { let t = ¶m.ident; - if self.preserved_types.contains_key(t) { + if self.preserved_types.contains_key(t) || patchable_generic_params.contains(t) { args.push(quote! { #t }); } } @@ -404,18 +470,3 @@ fn collect_used_simple_types(ty: &Type) -> Vec<&Ident> { collector.visit_type(ty); collector.used_simple_types } - -fn get_abstract_simple_type_name(t: &Type) -> Option<&Ident> { - match t { - Type::Path(tp) if !tp.path.segments.is_empty() => { - let last_segment = tp.path.segments.last()?; - // Ensure the path segment has no arguments (e.g., it's not `Vec` or `Option`). - if matches!(last_segment.arguments, PathArguments::None) { - Some(&last_segment.ident) - } else { - None - } - } - _ => None, - } -} diff --git a/patchable/src/lib.rs b/patchable/src/lib.rs index ecf948c..1b183dc 100644 --- a/patchable/src/lib.rs +++ b/patchable/src/lib.rs @@ -249,6 +249,32 @@ pub(crate) mod test { assert_eq!(s, UnitStruct); } + #[derive(Clone, Debug, Serialize, Deserialize, Patchable, PartialEq)] + struct Wrapper { + value: T, + } + + #[derive(Clone, Debug, Serialize, Deserialize, Patchable, PartialEq)] + struct Holder { + #[patchable] + inner: Wrapper, + } + + #[test] + fn test_patchable_generic_field_type() -> anyhow::Result<()> { + let holder = Holder { + inner: Wrapper { value: 7u32 }, + }; + let mut target = holder.clone(); + + let state: String = serde_json::to_string(&holder)?; + let patch = serde_json::from_str(&state)?; + + target.patch(patch); + assert_eq!(target, holder); + Ok(()) + } + #[derive(Debug, PartialEq)] struct FallibleStruct { value: i32,