diff --git a/sqlx-macros-core/src/derives/attributes.rs b/sqlx-macros-core/src/derives/attributes.rs index 6109863833..bf84df4474 100644 --- a/sqlx-macros-core/src/derives/attributes.rs +++ b/sqlx-macros-core/src/derives/attributes.rs @@ -59,6 +59,7 @@ pub struct SqlxContainerAttributes { pub repr: Option, pub no_pg_array: bool, pub default: bool, + pub try_from: Option, } pub enum JsonAttribute { @@ -82,6 +83,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result syn::Result()?; + let val: LitStr = meta.input.parse()?; + try_set!(try_from, val.parse()?, val); } else if meta.path.is_ident("rename_all") { meta.input.parse::()?; let lit: LitStr = meta.input.parse()?; @@ -140,6 +146,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result syn::Result { let attrs = parse_container_attributes(&input.attrs)?; + if let Some(try_from) = &attrs.try_from { + return expand_derive_decode_try_from(input, try_from); + } + match &input.data { Data::Struct(DataStruct { fields, .. }) if fields.len() == 1 && (matches!(fields, Fields::Unnamed(_)) || attrs.transparent) => @@ -46,6 +50,49 @@ pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { } } +fn expand_derive_decode_try_from( + input: &DeriveInput, + try_from: &syn::Type, +) -> syn::Result { + let ident = &input.ident; + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + let mut generics = generics.clone(); + generics + .params + .insert(0, parse_quote!(DB: ::sqlx::Database)); + generics.params.insert(0, parse_quote!('r)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#try_from: ::sqlx::decode::Decode<'r, DB>)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ident #ty_generics: ::std::convert::TryFrom<#try_from>)); + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + #[automatically_derived] + impl #impl_generics ::sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause { + fn decode( + value: ::ValueRef<'r>, + ) -> ::std::result::Result< + Self, + ::std::boxed::Box< + dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, + >, + > { + let value = <#try_from as ::sqlx::decode::Decode<'r, DB>>::decode(value)?; + <#ident #ty_generics as ::std::convert::TryFrom<#try_from>>::try_from(value) + .map_err(|e| ::sqlx::__spec_error!(e)) + } + } + )) +} + fn expand_derive_decode_transparent( input: &DeriveInput, field: &Field, diff --git a/sqlx-macros-core/src/derives/type.rs b/sqlx-macros-core/src/derives/type.rs index a66229287e..3e9a84a62f 100644 --- a/sqlx-macros-core/src/derives/type.rs +++ b/sqlx-macros-core/src/derives/type.rs @@ -12,6 +12,10 @@ use syn::{ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { let attrs = parse_container_attributes(&input.attrs)?; + if let Some(try_from) = &attrs.try_from { + return expand_derive_has_sql_type_try_from(input, try_from, attrs.no_pg_array); + } + match &input.data { // Newtype structs: // struct Foo(i32); @@ -52,6 +56,66 @@ pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { } } +fn expand_derive_has_sql_type_try_from( + input: &DeriveInput, + try_from: &syn::Type, + no_pg_array: bool, +) -> syn::Result { + let ident = &input.ident; + + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + let mut generics = generics.clone(); + let mut array_generics = generics.clone(); + + generics + .params + .insert(0, parse_quote!(DB: ::sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#try_from: ::sqlx::Type)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + array_generics + .make_where_clause() + .predicates + .push(parse_quote!(#try_from: ::sqlx::postgres::PgHasArrayType)); + let (array_impl_generics, _, array_where_clause) = array_generics.split_for_impl(); + + let mut tokens = quote!( + #[automatically_derived] + impl #impl_generics ::sqlx::Type for #ident #ty_generics #where_clause { + fn type_info() -> DB::TypeInfo { + <#try_from as ::sqlx::Type>::type_info() + } + + fn compatible(ty: &DB::TypeInfo) -> ::std::primitive::bool { + <#try_from as ::sqlx::Type>::compatible(ty) + } + } + ); + + if cfg!(feature = "postgres") && !no_pg_array { + tokens.extend(quote!( + #[automatically_derived] + impl #array_impl_generics ::sqlx::postgres::PgHasArrayType for #ident #ty_generics + #array_where_clause { + fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { + <#try_from as ::sqlx::postgres::PgHasArrayType>::array_type_info() + } + + fn array_compatible(ty: &::sqlx::postgres::PgTypeInfo) -> ::std::primitive::bool { + <#try_from as ::sqlx::postgres::PgHasArrayType>::array_compatible(ty) + } + } + )); + } + + Ok(tokens) +} + fn expand_derive_has_sql_type_transparent( input: &DeriveInput, field: &Field, diff --git a/tests/sqlite/derives.rs b/tests/sqlite/derives.rs index 3491ab8539..10e07fbf8e 100644 --- a/tests/sqlite/derives.rs +++ b/tests/sqlite/derives.rs @@ -32,3 +32,20 @@ test_type!(transparent_named(Sqlite, "0" == TransparentNamed { field: 0 }, "23523" == TransparentNamed { field: 23523 }, )); + +#[derive(PartialEq, Eq, Debug, sqlx::Type)] +#[sqlx(try_from = "i64")] +struct TryFromI64(i64); + +impl TryFrom for TryFromI64 { + type Error = std::num::TryFromIntError; + + fn try_from(value: i64) -> Result { + Ok(Self(i32::try_from(value)? as i64)) + } +} + +test_type!(try_from_i64(Sqlite, + "0" == TryFromI64(0), + "23523" == TryFromI64(23523), +));