Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 95 additions & 44 deletions patchable-macro/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//! 1. The companion patch struct (state struct).
//! 2. The `Patchable` trait implementation.

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use proc_macro_crate::{FoundCrate, crate_name};
use proc_macro2::{Span, TokenStream as TokenStream2};
Expand All @@ -20,9 +20,8 @@ use syn::{
};

#[derive(Clone, Copy, PartialEq, Eq)]
enum TypeUsage {
NotPatchable,
Patchable,
struct TypeUsage {
used_in_keep: bool,
}

pub(crate) struct MacroContext<'a> {
Expand All @@ -34,6 +33,8 @@ pub(crate) struct MacroContext<'a> {
fields: &'a Fields,
/// Mapping from preserved type to its usage flag.
preserved_types: HashMap<&'a Ident, TypeUsage>,
/// Patchable field types that should implement `Patchable`.
patchable_field_types: Vec<&'a Type>,
/// The list of actions to perform for each field when generating the `patch` method and the
/// state struct.
///
Expand Down Expand Up @@ -67,6 +68,7 @@ impl<'a> MacroContext<'a> {
}

let mut preserved_types: HashMap<&Ident, TypeUsage> = HashMap::new();
let mut patchable_field_types = Vec::new();
let mut field_actions = vec![];

let stateful_fields: Vec<&Field> = fields
Expand All @@ -84,25 +86,18 @@ impl<'a> MacroContext<'a> {
let field_type = &field.ty;

if has_patchable_attr(field) {
let Some(type_name) = get_abstract_simple_type_name(field_type) else {
return Err(syn::Error::new_spanned(
field_type,
"Only a simple generic type can be used", // TODO: remove this limit
));
};
// `Patchable` usage overrides `NotPatchable` usage.
preserved_types.insert(type_name, TypeUsage::Patchable);
patchable_field_types.push(field_type);

field_actions.push(FieldAction::Patch {
member,
ty: field_type,
});
} else {
for type_name in collect_used_simple_types(field_type) {
// Only mark as `NotPatchable` if not already marked as `Patchable`.
preserved_types
.entry(type_name)
.or_insert(TypeUsage::NotPatchable);
.and_modify(|usage| usage.used_in_keep = true)
.or_insert(TypeUsage { used_in_keep: true });
}
field_actions.push(FieldAction::Keep {
member,
Expand All @@ -119,6 +114,7 @@ impl<'a> MacroContext<'a> {
field_actions,
state_struct_name: quote::format_ident!("{}State", &input.ident),
crate_path: use_site_crate_path(),
patchable_field_types,
})
}

Expand All @@ -135,10 +131,25 @@ impl<'a> MacroContext<'a> {
let generic_params = quote! { <#(#state_generic_params),*> };

let mut bounds = Vec::new();
let patchable_generic_params = self.collect_patchable_generic_params();
for param in self.generics.type_params() {
if let Some(TypeUsage::Patchable) = self.preserved_types.get(&param.ident) {
let crate_root = &self.crate_path;
bounds.push(quote! { #param: #crate_root :: Patchable });
if patchable_generic_params.contains(&param.ident) {
bounds.push(quote! { #param: ::core::clone::Clone });
}
Comment on lines +134 to +138

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid forcing Clone on generics only used by patchable fields

This adds T: Clone for any generic referenced inside a #[patchable] field type. That over‑constrains cases like struct Outer<T> { #[patchable] inner: Wrapper<T> } where Wrapper<T>: Patchable does not require T: Clone (e.g., PhantomData wrappers or manual Patchable impls whose Patch doesn’t depend on T). The new bound will make such valid uses fail to compile even though Wrapper<T>: Patchable is satisfied.

Useful? React with 👍 / 👎.

}

let crate_root = &self.crate_path;
let mut seen_patchable_bounds = HashSet::new();
for field_type in &self.patchable_field_types {
let key = field_type.to_token_stream().to_string();
if seen_patchable_bounds.insert(key) {
bounds.push(quote! { #field_type: #crate_root :: Patchable });
if !self.is_generic_param_type(field_type) {
bounds.push(quote! {
for<'patchable_de> <#field_type as #crate_root :: Patchable>::Patch:
::serde::Deserialize<'patchable_de>
});
}
}
}
let where_clause = if bounds.is_empty() {
Expand Down Expand Up @@ -194,6 +205,7 @@ impl<'a> MacroContext<'a> {
}

fn generate_state_fields(&self) -> Vec<TokenStream2> {
let crate_root = &self.crate_path;
self.field_actions
.iter()
.map(|action| match action {
Expand All @@ -213,13 +225,23 @@ impl<'a> MacroContext<'a> {
member: FieldMember::Named(name),
ty,
} => {
quote! { #name : #ty :: Patch }
let patch_type = if self.is_generic_param_type(ty) {
quote! { #ty :: Patch }
} else {
quote! { <#ty as #crate_root :: Patchable>::Patch }
};
quote! { #name : #patch_type }
}
FieldAction::Patch {
member: FieldMember::Unnamed(_),
ty,
} => {
quote! { #ty :: Patch }
let patch_type = if self.is_generic_param_type(ty) {
quote! { #ty :: Patch }
} else {
quote! { <#ty as #crate_root :: Patchable>::Patch }
};
quote! { #patch_type }
}
})
.collect()
Expand Down Expand Up @@ -250,27 +272,70 @@ impl<'a> MacroContext<'a> {

fn collect_state_generics(&self) -> Vec<Ident> {
let mut generics = Vec::new();
let patchable_generic_params = self.collect_patchable_generic_params();
for param in self.generics.type_params() {
if self.preserved_types.contains_key(&param.ident) {
if self.preserved_types.contains_key(&param.ident)
|| patchable_generic_params.contains(&param.ident)
{
generics.push(param.ident.clone());
}
}
generics
}

fn collect_patchable_generic_params(&self) -> HashSet<&Ident> {
let mut patchable_generic_params = HashSet::new();
for field_type in &self.patchable_field_types {
for type_name in collect_used_simple_types(field_type) {
patchable_generic_params.insert(type_name);
}
}
patchable_generic_params
}

fn is_generic_param_type(&self, ty: &Type) -> bool {
let Type::Path(tp) = ty else {
return false;
};
if tp.qself.is_some() || tp.path.segments.len() != 1 {
return false;
}
let segment = &tp.path.segments[0];
if !matches!(segment.arguments, PathArguments::None) {
return false;
}
self.generics
.type_params()
.any(|param| param.ident == segment.ident)
}

fn build_bounds(&self) -> TokenStream2 {
let mut bounds = Vec::new();
let patchable_generic_params = self.collect_patchable_generic_params();
for param in self.generics.type_params() {
let t = &param.ident;
match self.preserved_types.get(t) {
Some(TypeUsage::Patchable) => {
let crate_root = &self.crate_path;
bounds.push(quote! { #t: #crate_root :: Patchable + ::core::clone::Clone });
}
Some(TypeUsage::NotPatchable) => {
bounds.push(quote! { #t: ::core::clone::Clone });
if patchable_generic_params.contains(t)
|| self
.preserved_types
.get(t)
.is_some_and(|usage| usage.used_in_keep)
{
bounds.push(quote! { #t: ::core::clone::Clone });
}
}

let crate_root = &self.crate_path;
let mut seen_patchable_bounds = HashSet::new();
for field_type in &self.patchable_field_types {
let key = field_type.to_token_stream().to_string();
if seen_patchable_bounds.insert(key) {
bounds.push(quote! { #field_type: #crate_root :: Patchable });
if !self.is_generic_param_type(field_type) {
bounds.push(quote! {
for<'patchable_de> <#field_type as #crate_root :: Patchable>::Patch:
::serde::Deserialize<'patchable_de>
});
}
None => {}
}
}

Expand Down Expand Up @@ -299,9 +364,10 @@ impl<'a> MacroContext<'a> {

fn build_associate_type_declaration(&self) -> TokenStream2 {
let mut args = Vec::new();
let patchable_generic_params = self.collect_patchable_generic_params();
for param in self.generics.type_params() {
let t = &param.ident;
if self.preserved_types.contains_key(t) {
if self.preserved_types.contains_key(t) || patchable_generic_params.contains(t) {
args.push(quote! { #t });
}
}
Expand Down Expand Up @@ -404,18 +470,3 @@ fn collect_used_simple_types(ty: &Type) -> Vec<&Ident> {
collector.visit_type(ty);
collector.used_simple_types
}

fn get_abstract_simple_type_name(t: &Type) -> Option<&Ident> {
match t {
Type::Path(tp) if !tp.path.segments.is_empty() => {
let last_segment = tp.path.segments.last()?;
// Ensure the path segment has no arguments (e.g., it's not `Vec<T>` or `Option<T>`).
if matches!(last_segment.arguments, PathArguments::None) {
Some(&last_segment.ident)
} else {
None
}
}
_ => None,
}
}
26 changes: 26 additions & 0 deletions patchable/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,32 @@ pub(crate) mod test {
assert_eq!(s, UnitStruct);
}

#[derive(Clone, Debug, Serialize, Deserialize, Patchable, PartialEq)]
struct Wrapper<T> {
value: T,
}

#[derive(Clone, Debug, Serialize, Deserialize, Patchable, PartialEq)]
struct Holder<T> {
#[patchable]
inner: Wrapper<T>,
}

#[test]
fn test_patchable_generic_field_type() -> anyhow::Result<()> {
let holder = Holder {
inner: Wrapper { value: 7u32 },
};
let mut target = holder.clone();

let state: String = serde_json::to_string(&holder)?;
let patch = serde_json::from_str(&state)?;

target.patch(patch);
assert_eq!(target, holder);
Ok(())
}

#[derive(Debug, PartialEq)]
struct FallibleStruct {
value: i32,
Expand Down