internal/impl: change unmarshal func return to unmarshalOptions

The fast-path unmarshal funcs return the number of bytes consumed.

Change these functions to return an unmarshalOutput struct instead, to
make it easier to add to the results. This is groundwork for allowing
the fast-path unmarshaler to indicate when the unmarshaled message is
known to be initialized.

Change-Id: Ia8c44731a88f5be969a55cd98ea26282f412c7ae
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215720
Reviewed-by: Joe Tsai <joetsai@google.com>
This commit is contained in:
Damien Neil 2020-01-21 14:25:12 -08:00
parent 61781dd92f
commit f0831e87e2
8 changed files with 775 additions and 623 deletions

@ -96,16 +96,17 @@ func append{{.Name}}(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]b
}
// consume{{.Name}} wire decodes a {{.GoType}} pointer as a {{.Name}}.
func consume{{.Name}}(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
func consume{{.Name}}(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
return out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
*p.{{.GoType.PointerMethod}}() = {{.ToGoType}}
return n, nil
out.n = n
return out, nil
}
var coder{{.Name}} = pointerCoderFuncs{
@ -127,19 +128,20 @@ func append{{.Name}}ValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalO
}
// consume{{.Name}}ValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}.
func consume{{.Name}}ValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
func consume{{.Name}}ValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
return out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return 0, errInvalidUTF8{}
return out, errInvalidUTF8{}
}
*p.{{.GoType.PointerMethod}}() = {{.ToGoType}}
return n, nil
out.n = n
return out, nil
}
var coder{{.Name}}ValidateUTF8 = pointerCoderFuncs{
@ -174,16 +176,17 @@ func append{{.Name}}NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions
{{if .ToGoTypeNoZero}}
// consume{{.Name}}NoZero wire decodes a {{.GoType}} pointer as a {{.Name}}.
// The zero value is not decoded.
func consume{{.Name}}NoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
func consume{{.Name}}NoZero(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
return out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
*p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
return n, nil
out.n = n
return out, nil
}
{{end}}
@ -211,19 +214,20 @@ func append{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ ma
{{if .ToGoTypeNoZero}}
// consume{{.Name}}NoZeroValidateUTF8 wire decodes a {{.GoType}} pointer as a {{.Name}}.
func consume{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
func consume{{.Name}}NoZeroValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
return out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return 0, errInvalidUTF8{}
return out, errInvalidUTF8{}
}
*p.{{.GoType.PointerMethod}}() = {{.ToGoTypeNoZero}}
return n, nil
out.n = n
return out, nil
}
{{end}}
@ -254,20 +258,21 @@ func append{{.Name}}Ptr(b []byte, p pointer, wiretag uint64, _ marshalOptions) (
}
// consume{{.Name}}Ptr wire decodes a *{{.GoType}} pointer as a {{.Name}}.
func consume{{.Name}}Ptr(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
func consume{{.Name}}Ptr(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
return out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
vp := p.{{.GoType.PointerMethod}}Ptr()
if *vp == nil {
*vp = new({{.GoType}})
}
**vp = {{.ToGoType}}
return n, nil
out.n = n
return out, nil
}
var coder{{.Name}}Ptr = pointerCoderFuncs{
@ -301,36 +306,38 @@ func append{{.Name}}Slice(b []byte, p pointer, wiretag uint64, _ marshalOptions)
}
// consume{{.Name}}Slice wire decodes a []{{.GoType}} pointer as a repeated {{.Name}}.
func consume{{.Name}}Slice(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
func consume{{.Name}}Slice(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (out unmarshalOutput, err error) {
sp := p.{{.GoType.PointerMethod}}Slice()
{{- if .WireType.Packable}}
if wtyp == wire.BytesType {
s := *sp
b, n = wire.ConsumeBytes(b)
b, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
for len(b) > 0 {
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
s = append(s, {{.ToGoType}})
b = b[n:]
}
*sp = s
return n, nil
out.n = n
return out, nil
}
{{- end}}
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
return out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
*sp = append(*sp, {{.ToGoType}})
return n, nil
out.n = n
return out, nil
}
var coder{{.Name}}Slice = pointerCoderFuncs{
@ -354,20 +361,21 @@ func append{{.Name}}SliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ mar
}
// consume{{.Name}}SliceValidateUTF8 wire decodes a []{{.GoType}} pointer as a repeated {{.Name}}.
func consume{{.Name}}SliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
func consume{{.Name}}SliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (out unmarshalOutput, err error) {
sp := p.{{.GoType.PointerMethod}}Slice()
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
return out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if !utf8.Valid{{if eq .Name "String"}}String{{end}}(v) {
return 0, errInvalidUTF8{}
return out, errInvalidUTF8{}
}
*sp = append(*sp, {{.ToGoType}})
return n, nil
out.n = n
return out, nil
}
var coder{{.Name}}SliceValidateUTF8 = pointerCoderFuncs{
@ -440,15 +448,16 @@ func append{{.Name}}Value(b []byte, v protoreflect.Value, wiretag uint64, _ mars
}
// consume{{.Name}}Value decodes a {{.GoType}} value as a {{.Name}}.
func consume{{.Name}}Value(b []byte, _ protoreflect.Value, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (protoreflect.Value, int, error) {
func consume{{.Name}}Value(b []byte, _ protoreflect.Value, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (_ protoreflect.Value, out unmarshalOutput, err error) {
if wtyp != {{.WireType.Expr}} {
return protoreflect.Value{}, 0, errUnknown
return protoreflect.Value{}, out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return protoreflect.Value{}, 0, wire.ParseError(n)
return protoreflect.Value{}, out, wire.ParseError(n)
}
return {{.ToValue}}, n, nil
out.n = n
return {{.ToValue}}, out, nil
}
var coder{{.Name}}Value = valueCoderFuncs{
@ -469,18 +478,19 @@ func append{{.Name}}ValueValidateUTF8(b []byte, v protoreflect.Value, wiretag ui
}
// consume{{.Name}}ValueValidateUTF8 decodes a {{.GoType}} value as a {{.Name}}.
func consume{{.Name}}ValueValidateUTF8(b []byte, _ protoreflect.Value, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (protoreflect.Value, int, error) {
func consume{{.Name}}ValueValidateUTF8(b []byte, _ protoreflect.Value, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (_ protoreflect.Value, out unmarshalOutput, err error) {
if wtyp != {{.WireType.Expr}} {
return protoreflect.Value{}, 0, errUnknown
return protoreflect.Value{}, out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return protoreflect.Value{}, 0, wire.ParseError(n)
return protoreflect.Value{}, out, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return protoreflect.Value{}, 0, errInvalidUTF8{}
return protoreflect.Value{}, out, errInvalidUTF8{}
}
return {{.ToValue}}, n, nil
out.n = n
return {{.ToValue}}, out, nil
}
var coder{{.Name}}ValueValidateUTF8 = valueCoderFuncs{
@ -516,34 +526,36 @@ func append{{.Name}}SliceValue(b []byte, listv protoreflect.Value, wiretag uint6
}
// consume{{.Name}}SliceValue wire decodes a []{{.GoType}} value as a repeated {{.Name}}.
func consume{{.Name}}SliceValue(b []byte, listv protoreflect.Value, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (_ protoreflect.Value, n int, err error) {
func consume{{.Name}}SliceValue(b []byte, listv protoreflect.Value, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (_ protoreflect.Value, out unmarshalOutput, err error) {
list := listv.List()
{{- if .WireType.Packable}}
if wtyp == wire.BytesType {
b, n = wire.ConsumeBytes(b)
b, n := wire.ConsumeBytes(b)
if n < 0 {
return protoreflect.Value{}, 0, wire.ParseError(n)
return protoreflect.Value{}, out, wire.ParseError(n)
}
for len(b) > 0 {
v, n := {{template "Consume" .}}
if n < 0 {
return protoreflect.Value{}, 0, wire.ParseError(n)
return protoreflect.Value{}, out, wire.ParseError(n)
}
list.Append({{.ToValue}})
b = b[n:]
}
return listv, n, nil
out.n = n
return listv, out, nil
}
{{- end}}
if wtyp != {{.WireType.Expr}} {
return protoreflect.Value{}, 0, errUnknown
return protoreflect.Value{}, out, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return protoreflect.Value{}, 0, wire.ParseError(n)
return protoreflect.Value{}, out, wire.ParseError(n)
}
list.Append({{.ToValue}})
return listv, n, nil
out.n = n
return listv, out, nil
}
var coder{{.Name}}SliceValue = valueCoderFuncs{

@ -52,7 +52,7 @@ func (mi *MessageInfo) initOneofFieldCoders(od pref.OneofDescriptor, si structIn
if funcs.isInit != nil {
needIsInit = true
}
cf.funcs.unmarshal = func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
cf.funcs.unmarshal = func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
var vw reflect.Value // pointer to wrapper type
vi := p.AsValueOf(ft).Elem() // oneof field value of interface kind
if !vi.IsNil() && !vi.Elem().IsNil() && vi.Elem().Elem().Type() == ot {
@ -60,12 +60,12 @@ func (mi *MessageInfo) initOneofFieldCoders(od pref.OneofDescriptor, si structIn
} else {
vw = reflect.New(ot)
}
n, err := funcs.unmarshal(b, pointerOfValue(vw).Apply(zeroOffset), wtyp, opts)
out, err := funcs.unmarshal(b, pointerOfValue(vw).Apply(zeroOffset), wtyp, opts)
if err != nil {
return 0, err
return out, err
}
vi.Set(vw)
return n, nil
return out, nil
}
}
getInfo := func(p pointer) (pointer, *oneofFieldInfo) {
@ -139,13 +139,13 @@ func makeWeakMessageFieldCoder(fd pref.FieldDescriptor) pointerCoderFuncs {
}
return appendMessage(b, m, wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
fs := p.WeakFields()
m, ok := fs.get(num)
if !ok {
lazyInit()
if messageType == nil {
return 0, errUnknown
return unmarshalOutput{}, errUnknown
}
m = messageType.New().Interface()
fs.set(num, m)
@ -171,7 +171,7 @@ func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCode
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageInfo(b, p, wiretag, mi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeMessageInfo(b, p, mi, wtyp, opts)
},
}
@ -191,7 +191,7 @@ func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCode
m := asMessage(p.AsValueOf(ft).Elem())
return appendMessage(b, m, wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft).Elem()
if mp.IsNil() {
mp.Set(reflect.New(ft.Elem()))
@ -216,21 +216,22 @@ func appendMessageInfo(b []byte, p pointer, wiretag uint64, mi *MessageInfo, opt
return mi.marshalAppendPointer(b, p.Elem(), opts)
}
func consumeMessageInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (int, error) {
func consumeMessageInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
return out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(mi.GoReflectType.Elem())))
}
if _, err := mi.unmarshalPointer(v, p.Elem(), 0, opts); err != nil {
return 0, err
return out, err
}
return n, nil
out.n = n
return out, nil
}
func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
@ -243,18 +244,19 @@ func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOption
return opts.Options().MarshalAppend(b, m)
}
func consumeMessage(b []byte, m proto.Message, wtyp wire.Type, opts unmarshalOptions) (int, error) {
func consumeMessage(b []byte, m proto.Message, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
return out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if err := opts.Options().Unmarshal(v, m); err != nil {
return 0, err
return out, err
}
return n, nil
out.n = n
return out, nil
}
func sizeMessageValue(v pref.Value, tagsize int, opts marshalOptions) int {
@ -267,10 +269,10 @@ func appendMessageValue(b []byte, v pref.Value, wiretag uint64, opts marshalOpti
return appendMessage(b, m, wiretag, opts)
}
func consumeMessageValue(b []byte, v pref.Value, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error) {
func consumeMessageValue(b []byte, v pref.Value, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, unmarshalOutput, error) {
m := v.Message().Interface()
n, err := consumeMessage(b, m, wtyp, opts)
return v, n, err
out, err := consumeMessage(b, m, wtyp, opts)
return v, out, err
}
func isInitMessageValue(v pref.Value) error {
@ -295,10 +297,10 @@ func appendGroupValue(b []byte, v pref.Value, wiretag uint64, opts marshalOption
return appendGroup(b, m, wiretag, opts)
}
func consumeGroupValue(b []byte, v pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error) {
func consumeGroupValue(b []byte, v pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, unmarshalOutput, error) {
m := v.Message().Interface()
n, err := consumeGroup(b, m, num, wtyp, opts)
return v, n, err
out, err := consumeGroup(b, m, num, wtyp, opts)
return v, out, err
}
var coderGroupValue = valueCoderFuncs{
@ -318,7 +320,7 @@ func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderF
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupType(b, p, wiretag, mi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeGroupType(b, p, mi, num, wtyp, opts)
},
}
@ -338,7 +340,7 @@ func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderF
m := asMessage(p.AsValueOf(ft).Elem())
return appendGroup(b, m, wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft).Elem()
if mp.IsNil() {
mp.Set(reflect.New(ft.Elem()))
@ -364,9 +366,9 @@ func appendGroupType(b []byte, p pointer, wiretag uint64, mi *MessageInfo, opts
return b, err
}
func consumeGroupType(b []byte, p pointer, mi *MessageInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (int, error) {
func consumeGroupType(b []byte, p pointer, mi *MessageInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.StartGroupType {
return 0, errUnknown
return out, errUnknown
}
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(mi.GoReflectType.Elem())))
@ -385,15 +387,16 @@ func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions)
return b, err
}
func consumeGroup(b []byte, m proto.Message, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (int, error) {
func consumeGroup(b []byte, m proto.Message, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.StartGroupType {
return 0, errUnknown
return out, errUnknown
}
b, n := wire.ConsumeGroup(num, b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
return n, opts.Options().Unmarshal(b, m)
out.n = n
return out, opts.Options().Unmarshal(b, m)
}
func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@ -405,7 +408,7 @@ func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointe
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSliceInfo(b, p, wiretag, mi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeMessageSliceInfo(b, p, mi, wtyp, opts)
},
}
@ -423,7 +426,7 @@ func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointe
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSlice(b, p, wiretag, ft, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeMessageSlice(b, p, ft, wtyp, opts)
},
isInit: func(p pointer) error {
@ -456,21 +459,22 @@ func appendMessageSliceInfo(b []byte, p pointer, wiretag uint64, mi *MessageInfo
return b, nil
}
func consumeMessageSliceInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (int, error) {
func consumeMessageSliceInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
return out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
m := reflect.New(mi.GoReflectType.Elem()).Interface()
mp := pointerOfIface(m)
if _, err := mi.unmarshalPointer(v, mp, 0, opts); err != nil {
return 0, err
return out, err
}
p.AppendPointerSlice(mp)
return n, nil
out.n = n
return out, nil
}
func isInitMessageSliceInfo(p pointer, mi *MessageInfo) error {
@ -509,20 +513,21 @@ func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type
return b, nil
}
func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp wire.Type, opts unmarshalOptions) (int, error) {
func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
return out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
mp := reflect.New(goType.Elem())
if err := opts.Options().Unmarshal(v, asMessage(mp)); err != nil {
return 0, err
return out, err
}
p.AppendPointerSlice(pointerOfValue(mp))
return n, nil
out.n = n
return out, nil
}
func isInitMessageSlice(p pointer, goType reflect.Type) error {
@ -565,21 +570,22 @@ func appendMessageSliceValue(b []byte, listv pref.Value, wiretag uint64, opts ma
return b, nil
}
func consumeMessageSliceValue(b []byte, listv pref.Value, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error) {
func consumeMessageSliceValue(b []byte, listv pref.Value, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (_ pref.Value, out unmarshalOutput, err error) {
list := listv.List()
if wtyp != wire.BytesType {
return pref.Value{}, 0, errUnknown
return pref.Value{}, out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return pref.Value{}, 0, wire.ParseError(n)
return pref.Value{}, out, wire.ParseError(n)
}
m := list.NewElement()
if err := opts.Options().Unmarshal(v, m.Message().Interface()); err != nil {
return pref.Value{}, 0, err
return pref.Value{}, out, err
}
list.Append(m)
return listv, n, nil
out.n = n
return listv, out, nil
}
func isInitMessageSliceValue(listv pref.Value) error {
@ -626,21 +632,22 @@ func appendGroupSliceValue(b []byte, listv pref.Value, wiretag uint64, opts mars
return b, nil
}
func consumeGroupSliceValue(b []byte, listv pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error) {
func consumeGroupSliceValue(b []byte, listv pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (_ pref.Value, out unmarshalOutput, err error) {
list := listv.List()
if wtyp != wire.StartGroupType {
return pref.Value{}, 0, errUnknown
return pref.Value{}, out, errUnknown
}
b, n := wire.ConsumeGroup(num, b)
if n < 0 {
return pref.Value{}, 0, wire.ParseError(n)
return pref.Value{}, out, wire.ParseError(n)
}
m := list.NewElement()
if err := opts.Options().Unmarshal(b, m.Message().Interface()); err != nil {
return pref.Value{}, 0, err
return pref.Value{}, out, err
}
list.Append(m)
return listv, n, nil
out.n = n
return listv, out, nil
}
var coderGroupSliceValue = valueCoderFuncs{
@ -660,7 +667,7 @@ func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerC
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSliceInfo(b, p, wiretag, mi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeGroupSliceInfo(b, p, num, wtyp, mi, opts)
},
}
@ -678,7 +685,7 @@ func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerC
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSlice(b, p, wiretag, ft, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeGroupSlice(b, p, num, wtyp, ft, opts)
},
isInit: func(p pointer) error {
@ -712,20 +719,21 @@ func appendGroupSlice(b []byte, p pointer, wiretag uint64, messageType reflect.T
return b, nil
}
func consumeGroupSlice(b []byte, p pointer, num wire.Number, wtyp wire.Type, goType reflect.Type, opts unmarshalOptions) (int, error) {
func consumeGroupSlice(b []byte, p pointer, num wire.Number, wtyp wire.Type, goType reflect.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.StartGroupType {
return 0, errUnknown
return out, errUnknown
}
b, n := wire.ConsumeGroup(num, b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
mp := reflect.New(goType.Elem())
if err := opts.Options().Unmarshal(b, asMessage(mp)); err != nil {
return 0, err
return out, err
}
p.AppendPointerSlice(pointerOfValue(mp))
return n, nil
out.n = n
return out, nil
}
func sizeGroupSliceInfo(p pointer, mi *MessageInfo, tagsize int, opts marshalOptions) int {
@ -751,18 +759,18 @@ func appendGroupSliceInfo(b []byte, p pointer, wiretag uint64, mi *MessageInfo,
return b, nil
}
func consumeGroupSliceInfo(b []byte, p pointer, num wire.Number, wtyp wire.Type, mi *MessageInfo, opts unmarshalOptions) (int, error) {
func consumeGroupSliceInfo(b []byte, p pointer, num wire.Number, wtyp wire.Type, mi *MessageInfo, opts unmarshalOptions) (unmarshalOutput, error) {
if wtyp != wire.StartGroupType {
return 0, errUnknown
return unmarshalOutput{}, errUnknown
}
m := reflect.New(mi.GoReflectType.Elem()).Interface()
mp := pointerOfIface(m)
n, err := mi.unmarshalPointer(b, mp, num, opts)
out, err := mi.unmarshalPointer(b, mp, num, opts)
if err != nil {
return 0, err
return out, err
}
p.AppendPointerSlice(mp)
return n, nil
return out, nil
}
func asMessage(v reflect.Value) pref.ProtoMessage {

File diff suppressed because it is too large Load Diff

@ -56,7 +56,7 @@ func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointer
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft)
if mp.Elem().IsNil() {
mp.Elem().Set(reflect.MakeMap(mapi.goType))
@ -104,13 +104,13 @@ func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions
return n
}
func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
return out, errUnknown
}
b, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
var (
key = mapi.keyZero
@ -119,50 +119,55 @@ func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opt
for len(b) > 0 {
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if num > wire.MaxValidNumber {
return 0, errors.New("invalid field number")
return out, errors.New("invalid field number")
}
b = b[n:]
err := errUnknown
switch num {
case 1:
var v pref.Value
v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
var o unmarshalOutput
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
if err != nil {
break
}
key = v
n = o.n
case 2:
var v pref.Value
v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
var o unmarshalOutput
v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
if err != nil {
break
}
val = v
n = o.n
}
if err == errUnknown {
n = wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
} else if err != nil {
return 0, err
return out, err
}
b = b[n:]
}
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
return n, nil
out.n = n
return out, nil
}
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
return out, errUnknown
}
b, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
var (
key = mapi.keyZero
@ -171,21 +176,23 @@ func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *map
for len(b) > 0 {
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if num > wire.MaxValidNumber {
return 0, errors.New("invalid field number")
return out, errors.New("invalid field number")
}
b = b[n:]
err := errUnknown
switch num {
case 1:
var v pref.Value
v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
var o unmarshalOutput
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
if err != nil {
break
}
key = v
n = o.n
case 2:
if wtyp != wire.BytesType {
break
@ -193,22 +200,23 @@ func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *map
var v []byte
v, n = wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
_, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
}
if err == errUnknown {
n = wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
} else if err != nil {
return 0, err
return out, err
}
b = b[n:]
}
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
return n, nil
out.n = n
return out, nil
}
func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {

@ -90,9 +90,9 @@ func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts ma
return b, nil
}
func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts unmarshalOptions) (int, error) {
func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts unmarshalOptions) (out unmarshalOutput, err error) {
if !flags.ProtoLegacy {
return 0, errors.New("no support for message_set_wire_format")
return out, errors.New("no support for message_set_wire_format")
}
ep := p.Apply(mi.extensionOffset).Extensions()
@ -101,7 +101,7 @@ func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts unmarshalOpt
}
ext := *ep
unknown := p.Apply(mi.unknownOffset).Bytes()
err := messageset.Unmarshal(b, true, func(num wire.Number, v []byte) error {
err = messageset.Unmarshal(b, true, func(num wire.Number, v []byte) error {
_, err := mi.unmarshalExtension(v, num, wire.BytesType, ext, opts)
if err == errUnknown {
*unknown = wire.AppendTag(*unknown, num, wire.BytesType)
@ -110,5 +110,6 @@ func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, opts unmarshalOpt
}
return err
})
return len(b), err
out.n = len(b)
return out, err
}

@ -24,16 +24,17 @@ func appendEnum(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byt
return b, nil
}
func consumeEnum(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
func consumeEnum(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.VarintType {
return 0, errUnknown
return out, errUnknown
}
v, n := wire.ConsumeVarint(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
p.v.Elem().SetInt(int64(v))
return n, nil
out.n = n
return out, nil
}
var coderEnum = pointerCoderFuncs{
@ -70,9 +71,9 @@ func appendEnumPtr(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]
return appendEnum(b, pointer{p.v.Elem()}, wiretag, opts)
}
func consumeEnumPtr(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (n int, err error) {
func consumeEnumPtr(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.VarintType {
return 0, errUnknown
return out, errUnknown
}
if p.v.Elem().IsNil() {
p.v.Elem().Set(reflect.New(p.v.Elem().Type().Elem()))
@ -103,36 +104,38 @@ func appendEnumSlice(b []byte, p pointer, wiretag uint64, opts marshalOptions) (
return b, nil
}
func consumeEnumSlice(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (n int, err error) {
func consumeEnumSlice(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
s := p.v.Elem()
if wtyp == wire.BytesType {
b, n = wire.ConsumeBytes(b)
b, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
for len(b) > 0 {
v, n := wire.ConsumeVarint(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
rv := reflect.New(s.Type().Elem()).Elem()
rv.SetInt(int64(v))
s.Set(reflect.Append(s, rv))
b = b[n:]
}
return n, nil
out.n = n
return out, nil
}
if wtyp != wire.VarintType {
return 0, errUnknown
return out, errUnknown
}
v, n := wire.ConsumeVarint(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
rv := reflect.New(s.Type().Elem()).Elem()
rv.SetInt(int64(v))
s.Set(reflect.Append(s, rv))
return n, nil
out.n = n
return out, nil
}
var coderEnumSlice = pointerCoderFuncs{

@ -17,7 +17,7 @@ import (
type pointerCoderFuncs struct {
size func(p pointer, tagsize int, opts marshalOptions) int
marshal func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error)
unmarshal func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error)
unmarshal func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error)
isInit func(p pointer) error
}
@ -25,7 +25,7 @@ type pointerCoderFuncs struct {
type valueCoderFuncs struct {
size func(v pref.Value, tagsize int, opts marshalOptions) int
marshal func(b []byte, v pref.Value, wiretag uint64, opts marshalOptions) ([]byte, error)
unmarshal func(b []byte, v pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error)
unmarshal func(b []byte, v pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, unmarshalOutput, error)
isInit func(v pref.Value) error
}

@ -57,6 +57,10 @@ func (o unmarshalOptions) Options() proto.UnmarshalOptions {
func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&unmarshalDiscardUnknown != 0 }
func (o unmarshalOptions) Resolver() preg.ExtensionTypeResolver { return o.resolver }
type unmarshalOutput struct {
n int // number of bytes consumed
}
// unmarshal is protoreflect.Methods.Unmarshal.
func (mi *MessageInfo) unmarshal(m pref.Message, in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
var p pointer
@ -77,7 +81,7 @@ func (mi *MessageInfo) unmarshal(m pref.Message, in piface.UnmarshalInput) (pifa
// This is a sentinel error which should never be visible to the user.
var errUnknown = errors.New("unknown")
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Number, opts unmarshalOptions) (int, error) {
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
mi.init()
if flags.ProtoLegacy && mi.isMessageSet {
return unmarshalMessageSet(mi, b, p, opts)
@ -89,18 +93,19 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
// TODO: inline 1 and 2 byte variants?
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if num > wire.MaxValidNumber {
return 0, errors.New("invalid field number")
return out, errors.New("invalid field number")
}
b = b[n:]
if wtyp == wire.EndGroupType {
if num != groupTag {
return 0, errors.New("mismatching end group marker")
return out, errors.New("mismatching end group marker")
}
return start - len(b), nil
out.n = start - len(b)
return out, nil
}
var f *coderFieldInfo
@ -115,7 +120,9 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
if f.funcs.unmarshal == nil {
break
}
n, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts)
var o unmarshalOutput
o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts)
n = o.n
default:
// Possible extension.
if exts == nil && mi.extensionOffset.IsValid() {
@ -127,15 +134,17 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
if exts == nil {
break
}
n, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
var o unmarshalOutput
o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
n = o.n
}
if err != nil {
if err != errUnknown {
return 0, err
return out, err
}
n = wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return 0, wire.ParseError(n)
return out, wire.ParseError(n)
}
if mi.unknownOffset.IsValid() {
u := p.Apply(mi.unknownOffset).Bytes()
@ -146,12 +155,13 @@ func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Numbe
b = b[n:]
}
if groupTag != 0 {
return 0, errors.New("missing end group marker")
return out, errors.New("missing end group marker")
}
return start, nil
out.n = start
return out, nil
}
func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (n int, err error) {
func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
x := exts[int32(num)]
xt := x.Type()
if xt == nil {
@ -159,14 +169,14 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.T
xt, err = opts.Resolver().FindExtensionByNumber(mi.Desc.FullName(), num)
if err != nil {
if err == preg.NotFound {
return 0, errUnknown
return out, errUnknown
}
return 0, err
return out, err
}
}
xi := getExtensionFieldInfo(xt)
if xi.funcs.unmarshal == nil {
return 0, errUnknown
return out, errUnknown
}
ival := x.Value()
if !ival.IsValid() && xi.unmarshalNeedsValue {
@ -175,11 +185,11 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.T
// concrete type.
ival = xt.New()
}
v, n, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
if err != nil {
return 0, err
return out, err
}
x.Set(xt, v)
exts[int32(num)] = x
return n, nil
return out, nil
}