From d8fe095c746041a1f22c8b39adfc5706b3645b33 Mon Sep 17 00:00:00 2001 From: WGH Date: Tue, 27 Feb 2024 19:35:49 +0300 Subject: [PATCH] Add option to prefer text unmarshaler when data type is String This library currently always prefers to use encoding.BinaryUnmashaler when it's implemented by the target type. This might lead to problems when the type also has a text representation implemented as encoding.TextUnmarshaler. Consider netip.Addr from stdlib as example, which implements both. This library won't be able decode MessagePack containing a string "192.0.2.1" into *netip.Addr because it will attempt to use encoding.BinaryUnmashaler which doesn't expect text representation. Fortunately, MessagePack has distinct string and binary types, so we can check the source data type before choosing the interface to use. This commit changes the behaviour of decoder as follows. When 1) target Go data type implements both BinaryUnmashaler and TextUnmarshaler 2) source MessagePack data type is a string TextUnmarshaler will be preferred over BinaryUnmashaler. This feature is gated behind a Decoder option, because it is potentially backward-incompatible change. See https://github.com/vmihailenco/msgpack/issues/370 --- decode.go | 15 +++++++++++ decode_value.go | 39 +++++++++++++++++++++++++--- types_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 4 deletions(-) diff --git a/decode.go b/decode.go index ea645aa..b381c3b 100644 --- a/decode.go +++ b/decode.go @@ -24,6 +24,7 @@ const ( disallowUnknownFieldsFlag usePreallocateValues disableAllocLimitFlag + preferTextUnmarshalerForString ) type bufReader interface { @@ -184,6 +185,20 @@ func (d *Decoder) DisableAllocLimit(on bool) { } } +// PreferTextUnmarshalerForString makes the decoder prefer [encoding.TextUnmarshaler] +// over [encoding.BinaryUnmarshaler] when both are implemented, and source +// MessagePack data is a String (as opposed to Binary). +// +// If this option is not enabled, [encoding.BinaryUnmarshaler] will be preferred +// instead, regardless of MessagePack data type. +func (d *Decoder) PreferTextUnmarshalerForString(on bool) { + if on { + d.flags |= preferTextUnmarshalerForString + } else { + d.flags &= ^preferTextUnmarshalerForString + } +} + // Buffered returns a reader of the data remaining in the Decoder's buffer. // The reader is valid until the next call to Decode. func (d *Decoder) Buffered() io.Reader { diff --git a/decode_value.go b/decode_value.go index c44a674..e1b5b45 100644 --- a/decode_value.go +++ b/decode_value.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "reflect" + + "github.com/vmihailenco/msgpack/v5/msgpcode" ) var ( @@ -70,10 +72,16 @@ func _getDecoder(typ reflect.Type) decoderFunc { if typ.Implements(unmarshalerType) { return nilAwareDecoder(typ, unmarshalValue) } - if typ.Implements(binaryUnmarshalerType) { + + implementsBinaryUnmarshaler := typ.Implements(binaryUnmarshalerType) + implementsTextUnmarshaler := typ.Implements(textUnmarshalerType) + if implementsBinaryUnmarshaler && implementsTextUnmarshaler { + return nilAwareDecoder(typ, unmarshalBinaryOrTextValue) + } + if implementsBinaryUnmarshaler { return nilAwareDecoder(typ, unmarshalBinaryValue) } - if typ.Implements(textUnmarshalerType) { + if implementsTextUnmarshaler { return nilAwareDecoder(typ, unmarshalTextValue) } @@ -86,10 +94,15 @@ func _getDecoder(typ reflect.Type) decoderFunc { if ptr.Implements(unmarshalerType) { return addrDecoder(nilAwareDecoder(typ, unmarshalValue)) } - if ptr.Implements(binaryUnmarshalerType) { + implementsBinaryUnmarshaler := ptr.Implements(binaryUnmarshalerType) + implementsTextUnmarshaler := ptr.Implements(textUnmarshalerType) + if implementsBinaryUnmarshaler && implementsTextUnmarshaler { + return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryOrTextValue)) + } + if implementsBinaryUnmarshaler { return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryValue)) } - if ptr.Implements(textUnmarshalerType) { + if implementsTextUnmarshaler { return addrDecoder(nilAwareDecoder(typ, unmarshalTextValue)) } } @@ -249,3 +262,21 @@ func unmarshalTextValue(d *Decoder, v reflect.Value) error { unmarshaler := v.Interface().(encoding.TextUnmarshaler) return unmarshaler.UnmarshalText(data) } + +func unmarshalBinaryOrTextValue(d *Decoder, v reflect.Value) error { + useText := false + if d.flags&preferTextUnmarshalerForString != 0 { + code, err := d.PeekCode() + if err != nil { + return err + } + if msgpcode.IsString(code) { + useText = true + } + } + if useText { + return unmarshalTextValue(d, v) + } else { + return unmarshalBinaryValue(d, v) + } +} diff --git a/types_test.go b/types_test.go index f8bfda1..384d6e5 100644 --- a/types_test.go +++ b/types_test.go @@ -2,6 +2,8 @@ package msgpack_test import ( "bytes" + "encoding" + "encoding/binary" "encoding/hex" "fmt" "math" @@ -427,6 +429,8 @@ type typeTest struct { wantnil bool wantzero bool wanted interface{} + + preferTextUnmarshalerForString bool } func (t typeTest) String() string { @@ -442,6 +446,36 @@ func (t *typeTest) requireErr(err error, s string) { } } +type binaryTextType uint32 + +// UnmarshalText implements encoding.TextUnmarshaler +func (v *binaryTextType) UnmarshalText(text []byte) error { + var b [4]byte + n, err := hex.Decode(b[:], text) + if err != nil { + return err + } + if n != 4 { + return fmt.Errorf("invalid length %d", n) + } + *v = binaryTextType(binary.BigEndian.Uint32(b[:])) + return nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (v *binaryTextType) UnmarshalBinary(data []byte) error { + if n := len(data); n != 4 { + return fmt.Errorf("invalid length %d", n) + } + *v = binaryTextType(binary.BigEndian.Uint32(data)) + return nil +} + +var ( + _ encoding.TextUnmarshaler = new(binaryTextType) + _ encoding.BinaryUnmarshaler = new(binaryTextType) +) + var ( intSlice = make([]int, 0, 3) repoURL, _ = url.Parse("https://github.com/vmihailenco/msgpack") @@ -622,6 +656,36 @@ var ( }, {in: big.NewInt(123), out: new(big.Int)}, + + { + in: "deadbeef", + out: new(binaryTextType), + wanted: binaryTextType(0xdeadbeef), + decErr: "invalid length 8", + + preferTextUnmarshalerForString: false, + }, + { + in: "deadbeef", + out: new(binaryTextType), + wanted: binaryTextType(0xdeadbeef), + + preferTextUnmarshalerForString: true, + }, + { + in: []byte{0xde, 0xad, 0xbe, 0xef}, + out: new(binaryTextType), + wanted: binaryTextType(0xdeadbeef), + + preferTextUnmarshalerForString: false, + }, + { + in: []byte{0xde, 0xad, 0xbe, 0xef}, + out: new(binaryTextType), + wanted: binaryTextType(0xdeadbeef), + + preferTextUnmarshalerForString: true, + }, } ) @@ -655,6 +719,9 @@ func TestTypes(t *testing.T) { } dec := msgpack.NewDecoder(&buf) + if test.preferTextUnmarshalerForString { + dec.PreferTextUnmarshalerForString(true) + } err = dec.Decode(test.out) if test.decErr != "" { test.requireErr(err, test.decErr)