Skip to content

Commit 2e64626

Browse files
authored
Implement drop-safe struct_extensions builder (#42)
* Implement drop-safe struct_extensions builder * Add is_init method to builder * assume_init_mut -> into_assume_init_mut * assume_init_forget -> finish * Update docs to use updated method names * ensure we don't overflow with left shifts
1 parent 511daeb commit 2e64626

File tree

3 files changed

+499
-37
lines changed

3 files changed

+499
-37
lines changed

wincode-derive/src/schema_read.rs

Lines changed: 232 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ use {
88
ast::{Data, Fields, Style},
99
Error, FromDeriveInput, Result,
1010
},
11-
proc_macro2::TokenStream,
11+
proc_macro2::{Span, TokenStream},
1212
quote::{format_ident, quote},
13-
syn::{parse_quote, DeriveInput, GenericParam, Generics, Type},
13+
syn::{parse_quote, DeriveInput, GenericParam, Generics, LitInt, LitStr, Path, Type},
1414
};
1515

1616
fn impl_struct(
@@ -185,7 +185,7 @@ fn impl_struct(
185185
/// ```
186186
///
187187
/// We cannot do this for enums, given the lack of facilities for placement initialization.
188-
fn impl_struct_extensions(args: &SchemaArgs) -> Result<TokenStream> {
188+
fn impl_struct_extensions(args: &SchemaArgs, crate_name: &Path) -> Result<TokenStream> {
189189
if !args.struct_extensions {
190190
return Ok(quote! {});
191191
}
@@ -205,6 +205,7 @@ fn impl_struct_extensions(args: &SchemaArgs) -> Result<TokenStream> {
205205
let dst = get_src_dst(args);
206206
let impl_generics = append_de_lifetime(&args.generics);
207207
let (_, ty_generics, where_clause) = args.generics.split_for_impl();
208+
let builder_ident = format_ident!("{struct_ident}UninitBuilder");
208209

209210
let helpers = fields.iter().enumerate().map(|(i, field)| {
210211
let ty = field.ty.with_lifetime("de");
@@ -214,8 +215,12 @@ fn impl_struct_extensions(args: &SchemaArgs) -> Result<TokenStream> {
214215
let uninit_mut_ident = format_ident!("uninit_{}_mut", ident_string);
215216
let read_field_ident = format_ident!("read_{}", ident_string);
216217
let write_uninit_field_ident = format_ident!("write_uninit_{}", ident_string);
218+
let deprecated_note = LitStr::new(
219+
&format!("Use `{builder_ident}` builder methods instead"),
220+
Span::call_site(),
221+
);
217222
let field_projection_type = if args.from.is_some() {
218-
// If the user is defining a mapping type, we need the type system to resolve the
223+
// If the user is defining a mapping type, we need the type system to resolve the
219224
// projection destination.
220225
quote! { <#ty as SchemaRead<'de>>::Dst }
221226
} else {
@@ -225,27 +230,244 @@ fn impl_struct_extensions(args: &SchemaArgs) -> Result<TokenStream> {
225230
};
226231
quote! {
227232
#[inline(always)]
233+
#[deprecated(since = "0.2.2", note = #deprecated_note)]
228234
#vis fn #uninit_mut_ident(dst: &mut MaybeUninit<#dst>) -> &mut MaybeUninit<#field_projection_type> {
229235
unsafe { &mut *(&raw mut (*dst.as_mut_ptr()).#ident).cast() }
230236
}
231237

232238
#[inline(always)]
239+
#[deprecated(since = "0.2.2", note = #deprecated_note)]
233240
#vis fn #read_field_ident(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<#dst>) -> ReadResult<()> {
234241
<#target as SchemaRead<'de>>::read(reader, Self::#uninit_mut_ident(dst))
235242
}
236243

237244
#[inline(always)]
245+
#[deprecated(since = "0.2.2", note = #deprecated_note)]
238246
#vis fn #write_uninit_field_ident(val: #field_projection_type, dst: &mut MaybeUninit<#dst>) {
239247
Self::#uninit_mut_ident(dst).write(val);
240248
}
241249
}
242250
});
243251

244-
Ok(quote!(
245-
impl #impl_generics #struct_ident #ty_generics #where_clause {
246-
#(#helpers)*
252+
// We modify the generics to add a lifetime parameter for the inner `MaybeUninit` struct.
253+
let mut builder_generics = args.generics.clone();
254+
// Add the lifetime for the inner `&mut MaybeUninit` struct.
255+
builder_generics
256+
.params
257+
.push(GenericParam::Lifetime(parse_quote!('_wincode_inner)));
258+
259+
let builder_dst = get_src_dst_fully_qualified(args);
260+
261+
let (builder_impl_generics, builder_ty_generics, builder_where_clause) =
262+
builder_generics.split_for_impl();
263+
// Determine how many bits are needed to track the initialization state of the fields.
264+
let (builder_bit_set_ty, builder_bit_set_bits): (Type, u32) = match fields.len() {
265+
len if len <= 8 => (parse_quote!(u8), u8::BITS),
266+
len if len <= 16 => (parse_quote!(u16), u16::BITS),
267+
len if len <= 32 => (parse_quote!(u32), u32::BITS),
268+
len if len <= 64 => (parse_quote!(u64), u64::BITS),
269+
len if len <= 128 => (parse_quote!(u128), u128::BITS),
270+
_ => {
271+
return Err(Error::custom(
272+
"`struct_extensions` is only supported for structs with up to 128 fields",
273+
))
274+
}
275+
};
276+
let builder_struct_decl = {
277+
// `split_for_impl` will strip default type and const parameters, so we collect them manually
278+
// to preserve the declarations on the original struct.
279+
let generic_type_params = builder_generics.type_params();
280+
let generic_lifetimes = builder_generics.lifetimes();
281+
let generic_const = builder_generics.const_params();
282+
let where_clause = builder_generics.where_clause.as_ref();
283+
quote! {
284+
/// A helper struct that provides convenience methods for reading and writing to a `MaybeUninit` struct
285+
/// with a bit-set tracking the initialization state of the fields.
286+
///
287+
/// The builder will drop all initialized fields in reverse order on drop. When the struct is fully initialized,
288+
/// you **must** call `finish` or `into_assume_init_mut` to forget the builder. Otherwise, all the
289+
/// initialized fields will be dropped when the builder is dropped.
290+
#[must_use]
291+
#vis struct #builder_ident < #(#generic_lifetimes,)* #(#generic_const,)* #(#generic_type_params,)* > #where_clause {
292+
inner: &'_wincode_inner mut core::mem::MaybeUninit<#builder_dst>,
293+
init_set: #builder_bit_set_ty,
294+
}
295+
}
296+
};
297+
298+
let builder_drop_impl = {
299+
// Drop all initialized fields in reverse order.
300+
let drops = fields.iter().rev().enumerate().map(|(index, field)| {
301+
// Compute the actual index relative to the reversed iterator.
302+
let real_index = fields.len() - 1 - index;
303+
let field_ident = field.struct_member_ident(real_index);
304+
// The corresponding bit for the field.
305+
let bit_set_index = LitInt::new(&(1u128 << real_index).to_string(), Span::call_site());
306+
quote! {
307+
if self.init_set & #bit_set_index != 0 {
308+
// SAFETY: We are dropping an initialized field.
309+
unsafe {
310+
ptr::drop_in_place(&raw mut (*dst_ptr).#field_ident);
311+
}
312+
}
313+
}
314+
});
315+
quote! {
316+
impl #builder_impl_generics Drop for #builder_ident #builder_ty_generics #builder_where_clause {
317+
fn drop(&mut self) {
318+
let dst_ptr = self.inner.as_mut_ptr();
319+
#(#drops)*
320+
}
321+
}
247322
}
248-
))
323+
};
324+
325+
let builder_impl = {
326+
let is_fully_init_mask = if fields.len() as u32 == builder_bit_set_bits {
327+
quote!(#builder_bit_set_ty::MAX)
328+
} else {
329+
let field_bits = LitInt::new(&fields.len().to_string(), Span::call_site());
330+
quote!(((1 as #builder_bit_set_ty) << #field_bits) - 1)
331+
};
332+
333+
quote! {
334+
impl #builder_impl_generics #builder_ident #builder_ty_generics #builder_where_clause {
335+
#vis const fn from_maybe_uninit_mut(inner: &'_wincode_inner mut MaybeUninit<#builder_dst>) -> Self {
336+
Self {
337+
inner,
338+
init_set: 0,
339+
}
340+
}
341+
342+
/// Check if the builder is fully initialized.
343+
///
344+
/// This will check if all field initialization bits are set.
345+
#[inline]
346+
#vis const fn is_init(&self) -> bool {
347+
self.init_set == #is_fully_init_mask
348+
}
349+
350+
/// Assume the builder is fully initialized, and return a mutable reference to the inner `MaybeUninit` struct.
351+
///
352+
/// The builder will be forgotten, so the drop logic will not longer run.
353+
///
354+
/// # Safety
355+
///
356+
/// Calling this when the content is not yet fully initialized causes undefined behavior: it is up to the caller
357+
/// to guarantee that the `MaybeUninit<T>` really is in an initialized state.
358+
#[inline]
359+
#vis const unsafe fn into_assume_init_mut(mut self) -> &'_wincode_inner mut #builder_dst {
360+
// SAFETY: reference lives beyond the scope of the builder, and builder is forgotten.
361+
let inner = unsafe { ptr::read(&mut self.inner) };
362+
mem::forget(self);
363+
// SAFETY: Caller asserts the `MaybeUninit<T>` is in an initialized state.
364+
unsafe {
365+
inner.assume_init_mut()
366+
}
367+
}
368+
369+
/// Forget the builder, disabling the drop logic.
370+
#[inline]
371+
#vis const fn finish(self) {
372+
mem::forget(self);
373+
}
374+
}
375+
}
376+
};
377+
378+
// Generate the helper methods for the builder.
379+
let builder_helpers = fields.iter().enumerate().map(|(i, field)| {
380+
let target = field.target_resolved();
381+
let target_reader_bound = target.with_lifetime("de");
382+
let ident = field.struct_member_ident(i);
383+
let ident_string = field.struct_member_ident_to_string(i);
384+
let uninit_mut_ident = format_ident!("uninit_{ident_string}_mut");
385+
let read_field_ident = format_ident!("read_{ident_string}");
386+
let write_uninit_field_ident = format_ident!("write_{ident_string}");
387+
let assume_init_field_ident = format_ident!("assume_init_{ident_string}");
388+
let init_with_field_ident = format_ident!("init_{ident_string}_with");
389+
390+
// The bit index for the field.
391+
let index_bit = LitInt::new(&(1u128 << i).to_string(), Span::call_site());
392+
let set_index_bit = quote! {
393+
self.init_set |= #index_bit;
394+
};
395+
396+
quote! {
397+
/// Get a mutable reference to the maybe uninitialized field.
398+
#[inline]
399+
#vis const fn #uninit_mut_ident(&mut self) -> &mut MaybeUninit<#target> {
400+
// SAFETY:
401+
// - `self.inner` is a valid reference to a `MaybeUninit<#builder_dst>`.
402+
// - We return the field as `&mut MaybeUninit<#target>`, so
403+
// the field is never exposed as initialized.
404+
unsafe { &mut *(&raw mut (*self.inner.as_mut_ptr()).#ident).cast() }
405+
}
406+
407+
/// Write a value to the maybe uninitialized field.
408+
#[inline]
409+
#vis const fn #write_uninit_field_ident(&mut self, val: #target) -> &mut Self {
410+
self.#uninit_mut_ident().write(val);
411+
#set_index_bit
412+
self
413+
}
414+
415+
/// Read a value from the reader into the maybe uninitialized field.
416+
#[inline]
417+
#vis fn #read_field_ident <'de>(&mut self, reader: &mut impl Reader<'de>) -> ReadResult<&mut Self> {
418+
// SAFETY:
419+
// - `self.inner` is a valid reference to a `MaybeUninit<#builder_dst>`.
420+
// - We return the field as `&mut MaybeUninit<#target>`, so
421+
// the field is never exposed as initialized.
422+
let proj = unsafe { &mut *(&raw mut (*self.inner.as_mut_ptr()).#ident).cast() };
423+
<#target_reader_bound as SchemaRead<'de>>::read(reader, proj)?;
424+
#set_index_bit
425+
Ok(self)
426+
}
427+
428+
/// Initialize the field with a given initializer function.
429+
///
430+
/// # Safety
431+
///
432+
/// The caller must guarantee that the initializer function fully initializes the field.
433+
#[inline]
434+
#vis unsafe fn #init_with_field_ident(&mut self, mut initializer: impl FnMut(&mut MaybeUninit<#target>) -> ReadResult<()>) -> ReadResult<&mut Self> {
435+
initializer(self.#uninit_mut_ident())?;
436+
#set_index_bit
437+
Ok(self)
438+
}
439+
440+
/// Mark the field as initialized.
441+
///
442+
/// # Safety
443+
///
444+
/// Caller must guarantee the field has been fully initialized prior to calling this.
445+
#[inline]
446+
#vis const unsafe fn #assume_init_field_ident(&mut self) -> &mut Self {
447+
#set_index_bit
448+
self
449+
}
450+
}
451+
});
452+
453+
Ok(quote! {
454+
const _: () = {
455+
use {
456+
core::{mem::{MaybeUninit, self}, ptr, marker::PhantomData},
457+
#crate_name::{SchemaRead, ReadResult, TypeMeta, io::Reader, error,},
458+
};
459+
impl #impl_generics #struct_ident #ty_generics #where_clause {
460+
#(#helpers)*
461+
}
462+
#builder_drop_impl
463+
#builder_impl
464+
465+
impl #builder_impl_generics #builder_ident #builder_ty_generics #builder_where_clause {
466+
#(#builder_helpers)*
467+
}
468+
};
469+
#builder_struct_decl
470+
})
249471
}
250472

251473
fn impl_enum(
@@ -389,7 +611,7 @@ pub(crate) fn generate(input: DeriveInput) -> Result<TokenStream> {
389611
let crate_name = get_crate_name(&args);
390612
let src_dst = get_src_dst(&args);
391613
let field_suppress = suppress_unused_fields(&args);
392-
let struct_extensions = impl_struct_extensions(&args)?;
614+
let struct_extensions = impl_struct_extensions(&args, &crate_name)?;
393615

394616
let (read_impl, type_meta_impl) = match &args.data {
395617
Data::Struct(fields) => {
@@ -426,8 +648,8 @@ pub(crate) fn generate(input: DeriveInput) -> Result<TokenStream> {
426648
Ok(())
427649
}
428650
}
429-
#struct_extensions
430651
};
652+
#struct_extensions
431653
#field_suppress
432654
})
433655
}

0 commit comments

Comments
 (0)