diff --git a/decode.go b/decode.go index ea645aa..b381c3b 100644 --- a/decode.go +++ b/decode.go @@ -24,6 +24,7 @@ const ( disallowUnknownFieldsFlag usePreallocateValues disableAllocLimitFlag + preferTextUnmarshalerForString ) type bufReader interface { @@ -184,6 +185,20 @@ func (d *Decoder) DisableAllocLimit(on bool) { } } +// PreferTextUnmarshalerForString makes the decoder prefer [encoding.TextUnmarshaler] +// over [encoding.BinaryUnmarshaler] when both are implemented, and source +// MessagePack data is a String (as opposed to Binary). +// +// If this option is not enabled, [encoding.BinaryUnmarshaler] will be preferred +// instead, regardless of MessagePack data type. +func (d *Decoder) PreferTextUnmarshalerForString(on bool) { + if on { + d.flags |= preferTextUnmarshalerForString + } else { + d.flags &= ^preferTextUnmarshalerForString + } +} + // Buffered returns a reader of the data remaining in the Decoder's buffer. // The reader is valid until the next call to Decode. func (d *Decoder) Buffered() io.Reader { diff --git a/decode_value.go b/decode_value.go index c44a674..e1b5b45 100644 --- a/decode_value.go +++ b/decode_value.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "reflect" + + "github.com/vmihailenco/msgpack/v5/msgpcode" ) var ( @@ -70,10 +72,16 @@ func _getDecoder(typ reflect.Type) decoderFunc { if typ.Implements(unmarshalerType) { return nilAwareDecoder(typ, unmarshalValue) } - if typ.Implements(binaryUnmarshalerType) { + + implementsBinaryUnmarshaler := typ.Implements(binaryUnmarshalerType) + implementsTextUnmarshaler := typ.Implements(textUnmarshalerType) + if implementsBinaryUnmarshaler && implementsTextUnmarshaler { + return nilAwareDecoder(typ, unmarshalBinaryOrTextValue) + } + if implementsBinaryUnmarshaler { return nilAwareDecoder(typ, unmarshalBinaryValue) } - if typ.Implements(textUnmarshalerType) { + if implementsTextUnmarshaler { return nilAwareDecoder(typ, unmarshalTextValue) } @@ -86,10 +94,15 @@ func _getDecoder(typ reflect.Type) decoderFunc { if ptr.Implements(unmarshalerType) { return addrDecoder(nilAwareDecoder(typ, unmarshalValue)) } - if ptr.Implements(binaryUnmarshalerType) { + implementsBinaryUnmarshaler := ptr.Implements(binaryUnmarshalerType) + implementsTextUnmarshaler := ptr.Implements(textUnmarshalerType) + if implementsBinaryUnmarshaler && implementsTextUnmarshaler { + return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryOrTextValue)) + } + if implementsBinaryUnmarshaler { return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryValue)) } - if ptr.Implements(textUnmarshalerType) { + if implementsTextUnmarshaler { return addrDecoder(nilAwareDecoder(typ, unmarshalTextValue)) } } @@ -249,3 +262,21 @@ func unmarshalTextValue(d *Decoder, v reflect.Value) error { unmarshaler := v.Interface().(encoding.TextUnmarshaler) return unmarshaler.UnmarshalText(data) } + +func unmarshalBinaryOrTextValue(d *Decoder, v reflect.Value) error { + useText := false + if d.flags&preferTextUnmarshalerForString != 0 { + code, err := d.PeekCode() + if err != nil { + return err + } + if msgpcode.IsString(code) { + useText = true + } + } + if useText { + return unmarshalTextValue(d, v) + } else { + return unmarshalBinaryValue(d, v) + } +} diff --git a/types_test.go b/types_test.go index f8bfda1..384d6e5 100644 --- a/types_test.go +++ b/types_test.go @@ -2,6 +2,8 @@ package msgpack_test import ( "bytes" + "encoding" + "encoding/binary" "encoding/hex" "fmt" "math" @@ -427,6 +429,8 @@ type typeTest struct { wantnil bool wantzero bool wanted interface{} + + preferTextUnmarshalerForString bool } func (t typeTest) String() string { @@ -442,6 +446,36 @@ func (t *typeTest) requireErr(err error, s string) { } } +type binaryTextType uint32 + +// UnmarshalText implements encoding.TextUnmarshaler +func (v *binaryTextType) UnmarshalText(text []byte) error { + var b [4]byte + n, err := hex.Decode(b[:], text) + if err != nil { + return err + } + if n != 4 { + return fmt.Errorf("invalid length %d", n) + } + *v = binaryTextType(binary.BigEndian.Uint32(b[:])) + return nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler +func (v *binaryTextType) UnmarshalBinary(data []byte) error { + if n := len(data); n != 4 { + return fmt.Errorf("invalid length %d", n) + } + *v = binaryTextType(binary.BigEndian.Uint32(data)) + return nil +} + +var ( + _ encoding.TextUnmarshaler = new(binaryTextType) + _ encoding.BinaryUnmarshaler = new(binaryTextType) +) + var ( intSlice = make([]int, 0, 3) repoURL, _ = url.Parse("https://github.com/vmihailenco/msgpack") @@ -622,6 +656,36 @@ var ( }, {in: big.NewInt(123), out: new(big.Int)}, + + { + in: "deadbeef", + out: new(binaryTextType), + wanted: binaryTextType(0xdeadbeef), + decErr: "invalid length 8", + + preferTextUnmarshalerForString: false, + }, + { + in: "deadbeef", + out: new(binaryTextType), + wanted: binaryTextType(0xdeadbeef), + + preferTextUnmarshalerForString: true, + }, + { + in: []byte{0xde, 0xad, 0xbe, 0xef}, + out: new(binaryTextType), + wanted: binaryTextType(0xdeadbeef), + + preferTextUnmarshalerForString: false, + }, + { + in: []byte{0xde, 0xad, 0xbe, 0xef}, + out: new(binaryTextType), + wanted: binaryTextType(0xdeadbeef), + + preferTextUnmarshalerForString: true, + }, } ) @@ -655,6 +719,9 @@ func TestTypes(t *testing.T) { } dec := msgpack.NewDecoder(&buf) + if test.preferTextUnmarshalerForString { + dec.PreferTextUnmarshalerForString(true) + } err = dec.Decode(test.out) if test.decErr != "" { test.requireErr(err, test.decErr)