Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
disallowUnknownFieldsFlag
usePreallocateValues
disableAllocLimitFlag
preferTextUnmarshalerForString
)

type bufReader interface {
Expand Down Expand Up @@ -184,6 +185,20 @@ func (d *Decoder) DisableAllocLimit(on bool) {
}
}

// PreferTextUnmarshalerForString makes the decoder prefer [encoding.TextUnmarshaler]
// over [encoding.BinaryUnmarshaler] when both are implemented, and source
// MessagePack data is a String (as opposed to Binary).
//
// If this option is not enabled, [encoding.BinaryUnmarshaler] will be preferred
// instead, regardless of MessagePack data type.
func (d *Decoder) PreferTextUnmarshalerForString(on bool) {
if on {
d.flags |= preferTextUnmarshalerForString
} else {
d.flags &= ^preferTextUnmarshalerForString
}
}

// Buffered returns a reader of the data remaining in the Decoder's buffer.
// The reader is valid until the next call to Decode.
func (d *Decoder) Buffered() io.Reader {
Expand Down
39 changes: 35 additions & 4 deletions decode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"reflect"

"github.com/vmihailenco/msgpack/v5/msgpcode"
)

var (
Expand Down Expand Up @@ -70,10 +72,16 @@ func _getDecoder(typ reflect.Type) decoderFunc {
if typ.Implements(unmarshalerType) {
return nilAwareDecoder(typ, unmarshalValue)
}
if typ.Implements(binaryUnmarshalerType) {

implementsBinaryUnmarshaler := typ.Implements(binaryUnmarshalerType)
implementsTextUnmarshaler := typ.Implements(textUnmarshalerType)
if implementsBinaryUnmarshaler && implementsTextUnmarshaler {
return nilAwareDecoder(typ, unmarshalBinaryOrTextValue)
}
if implementsBinaryUnmarshaler {
return nilAwareDecoder(typ, unmarshalBinaryValue)
}
if typ.Implements(textUnmarshalerType) {
if implementsTextUnmarshaler {
return nilAwareDecoder(typ, unmarshalTextValue)
}

Expand All @@ -86,10 +94,15 @@ func _getDecoder(typ reflect.Type) decoderFunc {
if ptr.Implements(unmarshalerType) {
return addrDecoder(nilAwareDecoder(typ, unmarshalValue))
}
if ptr.Implements(binaryUnmarshalerType) {
implementsBinaryUnmarshaler := ptr.Implements(binaryUnmarshalerType)
implementsTextUnmarshaler := ptr.Implements(textUnmarshalerType)
if implementsBinaryUnmarshaler && implementsTextUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryOrTextValue))
}
if implementsBinaryUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryValue))
}
if ptr.Implements(textUnmarshalerType) {
if implementsTextUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalTextValue))
}
}
Expand Down Expand Up @@ -249,3 +262,21 @@ func unmarshalTextValue(d *Decoder, v reflect.Value) error {
unmarshaler := v.Interface().(encoding.TextUnmarshaler)
return unmarshaler.UnmarshalText(data)
}

func unmarshalBinaryOrTextValue(d *Decoder, v reflect.Value) error {
useText := false
if d.flags&preferTextUnmarshalerForString != 0 {
code, err := d.PeekCode()
if err != nil {
return err
}
if msgpcode.IsString(code) {
useText = true
}
}
if useText {
return unmarshalTextValue(d, v)
} else {
return unmarshalBinaryValue(d, v)
}
}
67 changes: 67 additions & 0 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package msgpack_test

import (
"bytes"
"encoding"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
Expand Down Expand Up @@ -427,6 +429,8 @@ type typeTest struct {
wantnil bool
wantzero bool
wanted interface{}

preferTextUnmarshalerForString bool
}

func (t typeTest) String() string {
Expand All @@ -442,6 +446,36 @@ func (t *typeTest) requireErr(err error, s string) {
}
}

type binaryTextType uint32

// UnmarshalText implements encoding.TextUnmarshaler
func (v *binaryTextType) UnmarshalText(text []byte) error {
var b [4]byte
n, err := hex.Decode(b[:], text)
if err != nil {
return err
}
if n != 4 {
return fmt.Errorf("invalid length %d", n)
}
*v = binaryTextType(binary.BigEndian.Uint32(b[:]))
return nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler
func (v *binaryTextType) UnmarshalBinary(data []byte) error {
if n := len(data); n != 4 {
return fmt.Errorf("invalid length %d", n)
}
*v = binaryTextType(binary.BigEndian.Uint32(data))
return nil
}

var (
_ encoding.TextUnmarshaler = new(binaryTextType)
_ encoding.BinaryUnmarshaler = new(binaryTextType)
)

var (
intSlice = make([]int, 0, 3)
repoURL, _ = url.Parse("https://github.com/vmihailenco/msgpack")
Expand Down Expand Up @@ -622,6 +656,36 @@ var (
},

{in: big.NewInt(123), out: new(big.Int)},

{
in: "deadbeef",
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),
decErr: "invalid length 8",

preferTextUnmarshalerForString: false,
},
{
in: "deadbeef",
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: true,
},
{
in: []byte{0xde, 0xad, 0xbe, 0xef},
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: false,
},
{
in: []byte{0xde, 0xad, 0xbe, 0xef},
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: true,
},
}
)

Expand Down Expand Up @@ -655,6 +719,9 @@ func TestTypes(t *testing.T) {
}

dec := msgpack.NewDecoder(&buf)
if test.preferTextUnmarshalerForString {
dec.PreferTextUnmarshalerForString(true)
}
err = dec.Decode(test.out)
if test.decErr != "" {
test.requireErr(err, test.decErr)
Expand Down