internal/impl: add fast-path unmarshal

Benchmarks run with:
  go test ./benchmarks/ -bench=Wire  -benchtime=500ms -benchmem -count=8

Fast-path vs. parent commit:

  name                                      old time/op    new time/op    delta
  Wire/Unmarshal/google_message1_proto2-12    1.35µs ± 2%    0.45µs ± 4%  -67.01%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message1_proto3-12    1.07µs ± 1%    0.31µs ± 1%  -71.04%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message2-12            691µs ± 2%     188µs ± 2%  -72.78%  (p=0.000 n=7+8)

  name                                      old allocs/op  new allocs/op  delta
  Wire/Unmarshal/google_message1_proto2-12      60.0 ± 0%      25.0 ± 0%  -58.33%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message1_proto3-12      42.0 ± 0%       7.0 ± 0%  -83.33%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message2-12            28.6k ± 0%      8.5k ± 0%  -70.34%  (p=0.000 n=8+8)

Fast-path vs. -v1:

  name                                      old time/op    new time/op    delta
  Wire/Unmarshal/google_message1_proto2-12     702ns ± 1%     445ns ± 4%   -36.58%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message1_proto3-12     604ns ± 1%     311ns ± 1%   -48.54%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message2-12            179µs ± 3%     188µs ± 2%    +5.30%  (p=0.000 n=7+8)

  name                                      old allocs/op  new allocs/op  delta
  Wire/Unmarshal/google_message1_proto2-12      26.0 ± 0%      25.0 ± 0%    -3.85%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message1_proto3-12      8.00 ± 0%      7.00 ± 0%   -12.50%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message2-12            8.49k ± 0%     8.49k ± 0%    -0.01%  (p=0.000 n=8+8)

Change-Id: I6247ac3fd66a63d9acb902cbd192094ee3d151c3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185147
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2019-06-27 10:54:42 -07:00
parent 3d0706ac2e
commit e91877de26
15 changed files with 2694 additions and 360 deletions

View File

@ -50,6 +50,14 @@ b = wire.Append{{.WireType}}(b, {{.FromGoType}})
{{- end -}}
{{- end -}}
{{- define "Consume" -}}
{{- if eq .Name "String" -}}
wire.ConsumeString(b)
{{- else -}}
wire.Consume{{.WireType}}(b)
{{- end -}}
{{- end -}}
{{- range .}}
{{- if .FromGoType }}
// size{{.Name}} returns the size of wire encoding a {{.GoType}} pointer as a {{.Name}}.
@ -68,9 +76,23 @@ func append{{.Name}}(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]b
return b, nil
}
// consume{{.Name}} wire decodes a {{.GoType}} pointer as a {{.Name}}.
func consume{{.Name}}(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
}
*p.{{.GoType.PointerMethod}}() = {{.ToGoType}}
return n, nil
}
var coder{{.Name}} = pointerCoderFuncs{
size: size{{.Name}},
marshal: append{{.Name}},
size: size{{.Name}},
marshal: append{{.Name}},
unmarshal: consume{{.Name}},
}
// size{{.Name}} returns the size of wire encoding a {{.GoType}} pointer as a {{.Name}}.
@ -96,8 +118,9 @@ func append{{.Name}}NoZero(b []byte, p pointer, wiretag uint64, _ marshalOptions
}
var coder{{.Name}}NoZero = pointerCoderFuncs{
size: size{{.Name}}NoZero,
marshal: append{{.Name}}NoZero,
size: size{{.Name}}NoZero,
marshal: append{{.Name}}NoZero,
unmarshal: consume{{.Name}},
}
{{- if not .NoPointer}}
@ -110,7 +133,7 @@ func size{{.Name}}Ptr(p pointer, tagsize int, _ marshalOptions) (size int) {
return tagsize + {{template "Size" .}}
}
// append{{.Name}} wire encodes a *{{.GoType}} pointer as a {{.Name}}.
// append{{.Name}}Ptr wire encodes a *{{.GoType}} pointer as a {{.Name}}.
// It panics if the pointer is nil.
func append{{.Name}}Ptr(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
v := **p.{{.GoType.PointerMethod}}Ptr()
@ -119,9 +142,27 @@ func append{{.Name}}Ptr(b []byte, p pointer, wiretag uint64, _ marshalOptions) (
return b, nil
}
// consume{{.Name}}Ptr wire decodes a *{{.GoType}} pointer as a {{.Name}}.
func consume{{.Name}}Ptr(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
}
vp := p.{{.GoType.PointerMethod}}Ptr()
if *vp == nil {
*vp = new({{.GoType}})
}
**vp = {{.ToGoType}}
return n, nil
}
var coder{{.Name}}Ptr = pointerCoderFuncs{
size: size{{.Name}}Ptr,
marshal: append{{.Name}}Ptr,
size: size{{.Name}}Ptr,
marshal: append{{.Name}}Ptr,
unmarshal: consume{{.Name}}Ptr,
}
{{end}}
@ -148,9 +189,43 @@ func append{{.Name}}Slice(b []byte, p pointer, wiretag uint64, _ marshalOptions)
return b, nil
}
// consume{{.Name}}Slice wire decodes a []{{.GoType}} pointer as a repeated {{.Name}}.
func consume{{.Name}}Slice(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
sp := p.{{.GoType.PointerMethod}}Slice()
{{- if .WireType.Packable}}
if wtyp == wire.BytesType {
s := *sp
b, n = wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
for len(b) > 0 {
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
}
s = append(s, {{.ToGoType}})
b = b[n:]
}
*sp = s
return n, nil
}
{{- end}}
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return 0, wire.ParseError(n)
}
*sp = append(*sp, {{.ToGoType}})
return n, nil
}
var coder{{.Name}}Slice = pointerCoderFuncs{
size: size{{.Name}}Slice,
marshal: append{{.Name}}Slice,
size: size{{.Name}}Slice,
marshal: append{{.Name}}Slice,
unmarshal: consume{{.Name}}Slice,
}
{{if or (eq .WireType "Varint") (eq .WireType "Fixed32") (eq .WireType "Fixed64")}}
@ -194,8 +269,9 @@ func append{{.Name}}PackedSlice(b []byte, p pointer, wiretag uint64, _ marshalOp
}
var coder{{.Name}}PackedSlice = pointerCoderFuncs{
size: size{{.Name}}PackedSlice,
marshal: append{{.Name}}PackedSlice,
size: size{{.Name}}PackedSlice,
marshal: append{{.Name}}PackedSlice,
unmarshal: consume{{.Name}}Slice,
}
{{end}}
@ -215,9 +291,22 @@ func append{{.Name}}Iface(b []byte, ival interface{}, wiretag uint64, _ marshalO
return b, nil
}
// consume{{.Name}}Iface decodes a {{.GoType}} value as a {{.Name}}.
func consume{{.Name}}Iface(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
if wtyp != {{.WireType.Expr}} {
return nil, 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return nil, 0, wire.ParseError(n)
}
return {{.ToGoType}}, n, nil
}
var coder{{.Name}}Iface = ifaceCoderFuncs{
size: size{{.Name}}Iface,
marshal: append{{.Name}}Iface,
unmarshal: consume{{.Name}}Iface,
}
// size{{.Name}}SliceIface returns the size of wire encoding a []{{.GoType}} value as a repeated {{.Name}}.
@ -243,9 +332,44 @@ func append{{.Name}}SliceIface(b []byte, ival interface{}, wiretag uint64, _ mar
return b, nil
}
// consume{{.Name}}SliceIface wire decodes a []{{.GoType}} value as a repeated {{.Name}}.
func consume{{.Name}}SliceIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (_ interface{}, n int, err error) {
sp := ival.(*[]{{.GoType}})
{{- if .WireType.Packable}}
if wtyp == wire.BytesType {
s := *sp
b, n = wire.ConsumeBytes(b)
if n < 0 {
return nil, 0, wire.ParseError(n)
}
for len(b) > 0 {
v, n := {{template "Consume" .}}
if n < 0 {
return nil, 0, wire.ParseError(n)
}
s = append(s, {{.ToGoType}})
b = b[n:]
}
*sp = s
return ival, n, nil
}
{{- end}}
if wtyp != {{.WireType.Expr}} {
return nil, 0, errUnknown
}
v, n := {{template "Consume" .}}
if n < 0 {
return nil, 0, wire.ParseError(n)
}
*sp = append(*sp, {{.ToGoType}})
return ival, n, nil
}
var coder{{.Name}}SliceIface = ifaceCoderFuncs{
size: size{{.Name}}SliceIface,
marshal: append{{.Name}}SliceIface,
size: size{{.Name}}SliceIface,
marshal: append{{.Name}}SliceIface,
unmarshal: consume{{.Name}}SliceIface,
}
{{end -}}

View File

@ -86,6 +86,7 @@ type ProtoKind struct {
// Conversions to/from generated structures.
GoType GoType
ToGoType Expr
FromGoType Expr
NoPointer bool
}
@ -101,6 +102,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "wire.DecodeBool(v)",
FromValue: "wire.EncodeBool(v.Bool())",
GoType: GoBool,
ToGoType: "wire.DecodeBool(v)",
FromGoType: "wire.EncodeBool(v)",
},
{
@ -115,6 +117,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "int32(v)",
FromValue: "uint64(int32(v.Int()))",
GoType: GoInt32,
ToGoType: "int32(v)",
FromGoType: "uint64(v)",
},
{
@ -123,6 +126,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "int32(wire.DecodeZigZag(v & math.MaxUint32))",
FromValue: "wire.EncodeZigZag(int64(int32(v.Int())))",
GoType: GoInt32,
ToGoType: "int32(wire.DecodeZigZag(v & math.MaxUint32))",
FromGoType: "wire.EncodeZigZag(int64(v))",
},
{
@ -131,6 +135,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "uint32(v)",
FromValue: "uint64(uint32(v.Uint()))",
GoType: GoUint32,
ToGoType: "uint32(v)",
FromGoType: "uint64(v)",
},
{
@ -139,6 +144,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "int64(v)",
FromValue: "uint64(v.Int())",
GoType: GoInt64,
ToGoType: "int64(v)",
FromGoType: "uint64(v)",
},
{
@ -147,6 +153,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "wire.DecodeZigZag(v)",
FromValue: "wire.EncodeZigZag(v.Int())",
GoType: GoInt64,
ToGoType: "wire.DecodeZigZag(v)",
FromGoType: "wire.EncodeZigZag(v)",
},
{
@ -155,6 +162,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "v",
FromValue: "v.Uint()",
GoType: GoUint64,
ToGoType: "v",
FromGoType: "v",
},
{
@ -163,6 +171,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "int32(v)",
FromValue: "uint32(v.Int())",
GoType: GoInt32,
ToGoType: "int32(v)",
FromGoType: "uint32(v)",
},
{
@ -171,6 +180,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "uint32(v)",
FromValue: "uint32(v.Uint())",
GoType: GoUint32,
ToGoType: "v",
FromGoType: "v",
},
{
@ -179,6 +189,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "math.Float32frombits(uint32(v))",
FromValue: "math.Float32bits(float32(v.Float()))",
GoType: GoFloat32,
ToGoType: "math.Float32frombits(v)",
FromGoType: "math.Float32bits(v)",
},
{
@ -187,6 +198,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "int64(v)",
FromValue: "uint64(v.Int())",
GoType: GoInt64,
ToGoType: "int64(v)",
FromGoType: "uint64(v)",
},
{
@ -195,6 +207,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "v",
FromValue: "v.Uint()",
GoType: GoUint64,
ToGoType: "v",
FromGoType: "v",
},
{
@ -203,6 +216,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "math.Float64frombits(v)",
FromValue: "math.Float64bits(v.Float())",
GoType: GoFloat64,
ToGoType: "math.Float64frombits(v)",
FromGoType: "math.Float64bits(v)",
},
{
@ -211,6 +225,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "string(v)",
FromValue: "v.String()",
GoType: GoString,
ToGoType: "v",
FromGoType: "v",
},
{
@ -219,6 +234,7 @@ var ProtoKinds = []ProtoKind{
ToValue: "append(([]byte)(nil), v...)",
FromValue: "v.Bytes()",
GoType: GoBytes,
ToGoType: "append(([]byte)(nil), v...)",
FromGoType: "v",
NoPointer: true,
},

View File

@ -425,11 +425,6 @@ func AppendBytes(b []byte, v []byte) []byte {
return append(AppendVarint(b, uint64(len(v))), v...)
}
// AppendString appends v to b as a length-prefixed bytes value.
func AppendString(b []byte, v string) []byte {
return append(AppendVarint(b, uint64(len(v))), v...)
}
// ConsumeBytes parses b as a length-prefixed bytes value, reporting its length.
// This returns a negative length upon an error (see ParseError).
func ConsumeBytes(b []byte) (v []byte, n int) {
@ -449,6 +444,18 @@ func SizeBytes(n int) int {
return SizeVarint(uint64(n)) + n
}
// AppendString appends v to b as a length-prefixed bytes value.
func AppendString(b []byte, v string) []byte {
return append(AppendVarint(b, uint64(len(v))), v...)
}
// ConsumeString parses b as a length-prefixed bytes value, reporting its length.
// This returns a negative length upon an error (see ParseError).
func ConsumeString(b []byte) (v string, n int) {
bb, n := ConsumeBytes(b)
return string(bb), n
}
// AppendGroup appends v to b as group value, with a trailing end group marker.
// The value v must not contain the end marker.
func AppendGroup(b []byte, num Number, v []byte) []byte {

View File

@ -13,9 +13,10 @@ import (
)
type extensionFieldInfo struct {
wiretag uint64
tagsize int
funcs ifaceCoderFuncs
wiretag uint64
tagsize int
unmarshalNeedsValue bool
funcs ifaceCoderFuncs
}
func (mi *MessageInfo) extensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
@ -34,7 +35,17 @@ func (mi *MessageInfo) extensionFieldInfo(xt pref.ExtensionType) *extensionField
tagsize: wire.SizeVarint(wiretag),
funcs: encoderFuncsForValue(xt, xt.GoType()),
}
// Does the unmarshal function need a value passed to it?
// This is true for composite types, where we pass in a message, list, or map to fill in,
// and for enums, where we pass in a prototype value to specify the concrete enum type.
switch xt.Kind() {
case pref.MessageKind, pref.GroupKind, pref.EnumKind:
e.unmarshalNeedsValue = true
default:
if xt.Cardinality() == pref.Repeated {
e.unmarshalNeedsValue = true
}
}
mi.extensionFieldInfosMu.Lock()
if mi.extensionFieldInfos == nil {
mi.extensionFieldInfos = make(map[pref.ExtensionType]*extensionFieldInfo)

View File

@ -5,7 +5,6 @@
package impl
import (
"fmt"
"reflect"
"unicode/utf8"
@ -19,61 +18,59 @@ type errInvalidUTF8 struct{}
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
func (errInvalidUTF8) InvalidUTF8() bool { return true }
func makeOneofFieldCoder(fs reflect.StructField, od pref.OneofDescriptor, structFields map[pref.FieldNumber]reflect.StructField, otypes map[pref.FieldNumber]reflect.Type) pointerCoderFuncs {
type oneofFieldInfo struct {
wiretag uint64
tagsize int
funcs pointerCoderFuncs
}
oneofFieldInfos := make(map[reflect.Type]oneofFieldInfo)
for i, fields := 0, od.Fields(); i < fields.Len(); i++ {
fd := fields.Get(i)
ot := otypes[fd.Number()]
wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
oneofFieldInfos[ot] = oneofFieldInfo{
wiretag: wiretag,
tagsize: wire.SizeVarint(wiretag),
funcs: fieldCoder(fd, ot.Field(0).Type),
}
}
func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFuncs {
ot := si.oneofWrappersByNumber[fd.Number()]
funcs := fieldCoder(fd, ot.Field(0).Type)
fs := si.oneofsByName[fd.ContainingOneof().Name()]
ft := fs.Type
getInfo := func(p pointer) (pointer, oneofFieldInfo) {
wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
tagsize := wire.SizeVarint(wiretag)
getInfo := func(p pointer) (pointer, bool) {
v := p.AsValueOf(ft).Elem()
if v.IsNil() {
return pointer{}, oneofFieldInfo{}
return pointer{}, false
}
v = v.Elem() // interface -> *struct
telem := v.Elem().Type()
info, ok := oneofFieldInfos[telem]
if !ok {
panic(fmt.Errorf("invalid oneof type %v", telem))
if v.Elem().Type() != ot {
return pointer{}, false
}
return pointerOfValue(v).Apply(zeroOffset), info
return pointerOfValue(v).Apply(zeroOffset), true
}
return pointerCoderFuncs{
pcf := pointerCoderFuncs{
size: func(p pointer, _ int, opts marshalOptions) int {
v, info := getInfo(p)
if info.funcs.size == nil {
v, ok := getInfo(p)
if !ok {
return 0
}
return info.funcs.size(v, info.tagsize, opts)
return funcs.size(v, tagsize, opts)
},
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
v, info := getInfo(p)
if info.funcs.marshal == nil {
v, ok := getInfo(p)
if !ok {
return b, nil
}
return info.funcs.marshal(b, v, info.wiretag, opts)
return funcs.marshal(b, v, wiretag, opts)
},
isInit: func(p pointer) error {
v, info := getInfo(p)
if info.funcs.isInit == nil {
return nil
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
v := reflect.New(ot)
n, err := funcs.unmarshal(b, pointerOfValue(v).Apply(zeroOffset), wtyp, opts)
if err != nil {
return 0, err
}
return info.funcs.isInit(v)
p.AsValueOf(ft).Elem().Set(v)
return n, nil
},
}
if funcs.isInit != nil {
pcf.isInit = func(p pointer) error {
v, ok := getInfo(p)
if !ok {
return nil
}
return funcs.isInit(v)
}
}
return pcf
}
func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@ -85,6 +82,9 @@ func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCode
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageInfo(b, p, wiretag, fi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
return consumeMessageInfo(b, p, fi, wtyp, opts)
},
isInit: func(p pointer) error {
return fi.isInitializedPointer(p.Elem())
},
@ -99,6 +99,13 @@ func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCode
m := asMessage(p.AsValueOf(ft).Elem())
return appendMessage(b, m, wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
mp := p.AsValueOf(ft).Elem()
if mp.IsNil() {
mp.Set(reflect.New(ft.Elem()))
}
return consumeMessage(b, asMessage(mp), wtyp, opts)
},
isInit: func(p pointer) error {
m := asMessage(p.AsValueOf(ft).Elem())
return proto.IsInitialized(m)
@ -117,6 +124,23 @@ func appendMessageInfo(b []byte, p pointer, wiretag uint64, mi *MessageInfo, opt
return mi.marshalAppendPointer(b, p.Elem(), opts)
}
func consumeMessageInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (int, error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(mi.GoType.Elem())))
}
if _, err := mi.unmarshalPointer(v, p.Elem(), 0, opts); err != nil {
return 0, err
}
return n, nil
}
func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
return wire.SizeBytes(proto.Size(m)) + tagsize
}
@ -127,6 +151,20 @@ func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOption
return opts.Options().MarshalAppend(b, m)
}
func consumeMessage(b []byte, m proto.Message, wtyp wire.Type, opts unmarshalOptions) (int, error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if err := opts.Options().Unmarshal(v, m); err != nil {
return 0, err
}
return n, nil
}
func sizeMessageIface(ival interface{}, tagsize int, opts marshalOptions) int {
m := Export{}.MessageOf(ival).Interface()
return sizeMessage(m, tagsize, opts)
@ -137,18 +175,26 @@ func appendMessageIface(b []byte, ival interface{}, wiretag uint64, opts marshal
return appendMessage(b, m, wiretag, opts)
}
func consumeMessageIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
m := Export{}.MessageOf(ival).Interface()
n, err := consumeMessage(b, m, wtyp, opts)
return ival, n, err
}
func isInitMessageIface(ival interface{}) error {
m := Export{}.MessageOf(ival).Interface()
return proto.IsInitialized(m)
}
var coderMessageIface = ifaceCoderFuncs{
size: sizeMessageIface,
marshal: appendMessageIface,
isInit: isInitMessageIface,
size: sizeMessageIface,
marshal: appendMessageIface,
unmarshal: consumeMessageIface,
isInit: isInitMessageIface,
}
func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
num := fd.Number()
if fi, ok := getMessageInfo(ft); ok {
return pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
@ -157,6 +203,9 @@ func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderF
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupType(b, p, wiretag, fi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
return consumeGroupType(b, p, fi, num, wtyp, opts)
},
isInit: func(p pointer) error {
return fi.isInitializedPointer(p.Elem())
},
@ -171,6 +220,13 @@ func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderF
m := asMessage(p.AsValueOf(ft).Elem())
return appendGroup(b, m, wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
mp := p.AsValueOf(ft).Elem()
if mp.IsNil() {
mp.Set(reflect.New(ft.Elem()))
}
return consumeGroup(b, asMessage(mp), num, wtyp, opts)
},
isInit: func(p pointer) error {
m := asMessage(p.AsValueOf(ft).Elem())
return proto.IsInitialized(m)
@ -190,6 +246,16 @@ func appendGroupType(b []byte, p pointer, wiretag uint64, mi *MessageInfo, opts
return b, err
}
func consumeGroupType(b []byte, p pointer, mi *MessageInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (int, error) {
if wtyp != wire.StartGroupType {
return 0, errUnknown
}
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(mi.GoType.Elem())))
}
return mi.unmarshalPointer(b, p.Elem(), num, opts)
}
func sizeGroup(m proto.Message, tagsize int, _ marshalOptions) int {
return 2*tagsize + proto.Size(m)
}
@ -201,30 +267,47 @@ func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions)
return b, err
}
func sizeGroupIface(ival interface{}, tagsize int, opts marshalOptions) int {
m := Export{}.MessageOf(ival).Interface()
return sizeGroup(m, tagsize, opts)
func consumeGroup(b []byte, m proto.Message, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (int, error) {
if wtyp != wire.StartGroupType {
return 0, errUnknown
}
b, n := wire.ConsumeGroup(num, b)
if n < 0 {
return 0, wire.ParseError(n)
}
return n, opts.Options().Unmarshal(b, m)
}
func appendGroupIface(b []byte, ival interface{}, wiretag uint64, opts marshalOptions) ([]byte, error) {
m := Export{}.MessageOf(ival).Interface()
return appendGroup(b, m, wiretag, opts)
}
var coderGroupIface = ifaceCoderFuncs{
size: sizeGroupIface,
marshal: appendGroupIface,
isInit: isInitMessageIface,
func makeGroupValueCoder(fd pref.FieldDescriptor, ft reflect.Type) ifaceCoderFuncs {
return ifaceCoderFuncs{
size: func(ival interface{}, tagsize int, opts marshalOptions) int {
m := Export{}.MessageOf(ival).Interface()
return sizeGroup(m, tagsize, opts)
},
marshal: func(b []byte, ival interface{}, wiretag uint64, opts marshalOptions) ([]byte, error) {
m := Export{}.MessageOf(ival).Interface()
return appendGroup(b, m, wiretag, opts)
},
unmarshal: func(b []byte, ival interface{}, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
m := Export{}.MessageOf(ival).Interface()
n, err := consumeGroup(b, m, num, wtyp, opts)
return ival, n, err
},
isInit: isInitMessageIface,
}
}
func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
if fi, ok := getMessageInfo(ft); ok {
return pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
return sizeMessageSliceInfo(p, fi, tagsize, opts)
},
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSliceInfo(b, p, wiretag, fi, opts)
},
size: func(p pointer, tagsize int, opts marshalOptions) int {
return sizeMessageSliceInfo(p, fi, tagsize, opts)
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
return consumeMessageSliceInfo(b, p, fi, wtyp, opts)
},
isInit: func(p pointer) error {
return isInitMessageSliceInfo(p, fi)
@ -238,6 +321,9 @@ func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointe
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSlice(b, p, wiretag, ft, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
return consumeMessageSlice(b, p, ft, wtyp, opts)
},
isInit: func(p pointer) error {
return isInitMessageSlice(p, ft)
},
@ -268,6 +354,23 @@ func appendMessageSliceInfo(b []byte, p pointer, wiretag uint64, mi *MessageInfo
return b, nil
}
func consumeMessageSliceInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (int, error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
m := reflect.New(mi.GoType.Elem()).Interface()
mp := pointerOfIface(m)
if _, err := mi.unmarshalPointer(v, mp, 0, opts); err != nil {
return 0, err
}
p.AppendPointerSlice(mp)
return n, nil
}
func isInitMessageSliceInfo(p pointer, mi *MessageInfo) error {
s := p.PointerSlice()
for _, v := range s {
@ -282,7 +385,7 @@ func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOpti
s := p.PointerSlice()
n := 0
for _, v := range s {
m := Export{}.MessageOf(v.AsValueOf(goType.Elem()).Interface()).Interface()
m := asMessage(v.AsValueOf(goType.Elem()))
n += wire.SizeBytes(proto.Size(m)) + tagsize
}
return n
@ -292,7 +395,7 @@ func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type
s := p.PointerSlice()
var err error
for _, v := range s {
m := Export{}.MessageOf(v.AsValueOf(goType.Elem()).Interface()).Interface()
m := asMessage(v.AsValueOf(goType.Elem()))
b = wire.AppendVarint(b, wiretag)
siz := proto.Size(m)
b = wire.AppendVarint(b, uint64(siz))
@ -304,10 +407,26 @@ func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type
return b, nil
}
func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp wire.Type, opts unmarshalOptions) (int, error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
mp := reflect.New(goType.Elem())
if err := opts.Options().Unmarshal(v, asMessage(mp)); err != nil {
return 0, err
}
p.AppendPointerSlice(pointerOfValue(mp))
return n, nil
}
func isInitMessageSlice(p pointer, goType reflect.Type) error {
s := p.PointerSlice()
for _, v := range s {
m := Export{}.MessageOf(v.AsValueOf(goType.Elem()).Interface()).Interface()
m := asMessage(v.AsValueOf(goType.Elem()))
if err := proto.IsInitialized(m); err != nil {
return err
}
@ -327,18 +446,26 @@ func appendMessageSliceIface(b []byte, ival interface{}, wiretag uint64, opts ma
return appendMessageSlice(b, p, wiretag, reflect.TypeOf(ival).Elem().Elem(), opts)
}
func consumeMessageSliceIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
p := pointerOfIface(ival)
n, err := consumeMessageSlice(b, p, reflect.TypeOf(ival).Elem().Elem(), wtyp, opts)
return ival, n, err
}
func isInitMessageSliceIface(ival interface{}) error {
p := pointerOfIface(ival)
return isInitMessageSlice(p, reflect.TypeOf(ival).Elem().Elem())
}
var coderMessageSliceIface = ifaceCoderFuncs{
size: sizeMessageSliceIface,
marshal: appendMessageSliceIface,
isInit: isInitMessageSliceIface,
size: sizeMessageSliceIface,
marshal: appendMessageSliceIface,
unmarshal: consumeMessageSliceIface,
isInit: isInitMessageSliceIface,
}
func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
num := fd.Number()
if fi, ok := getMessageInfo(ft); ok {
return pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
@ -347,6 +474,9 @@ func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerC
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSliceInfo(b, p, wiretag, fi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
return consumeGroupSliceInfo(b, p, num, wtyp, fi, opts)
},
isInit: func(p pointer) error {
return isInitMessageSliceInfo(p, fi)
},
@ -359,6 +489,9 @@ func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerC
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSlice(b, p, wiretag, ft, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
return consumeGroupSlice(b, p, num, wtyp, ft, opts)
},
isInit: func(p pointer) error {
return isInitMessageSlice(p, ft)
},
@ -369,7 +502,7 @@ func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, _ marshalO
s := p.PointerSlice()
n := 0
for _, v := range s {
m := Export{}.MessageOf(v.AsValueOf(messageType.Elem()).Interface()).Interface()
m := asMessage(v.AsValueOf(messageType.Elem()))
n += 2*tagsize + proto.Size(m)
}
return n
@ -379,7 +512,7 @@ func appendGroupSlice(b []byte, p pointer, wiretag uint64, messageType reflect.T
s := p.PointerSlice()
var err error
for _, v := range s {
m := Export{}.MessageOf(v.AsValueOf(messageType.Elem()).Interface()).Interface()
m := asMessage(v.AsValueOf(messageType.Elem()))
b = wire.AppendVarint(b, wiretag) // start group
b, err = opts.Options().MarshalAppend(b, m)
if err != nil {
@ -390,6 +523,22 @@ func appendGroupSlice(b []byte, p pointer, wiretag uint64, messageType reflect.T
return b, nil
}
func consumeGroupSlice(b []byte, p pointer, num wire.Number, wtyp wire.Type, goType reflect.Type, opts unmarshalOptions) (int, error) {
if wtyp != wire.StartGroupType {
return 0, errUnknown
}
b, n := wire.ConsumeGroup(num, b)
if n < 0 {
return 0, wire.ParseError(n)
}
mp := reflect.New(goType.Elem())
if err := opts.Options().Unmarshal(b, asMessage(mp)); err != nil {
return 0, err
}
p.AppendPointerSlice(pointerOfValue(mp))
return n, nil
}
func sizeGroupSliceInfo(p pointer, mi *MessageInfo, tagsize int, opts marshalOptions) int {
s := p.PointerSlice()
n := 0
@ -413,6 +562,20 @@ func appendGroupSliceInfo(b []byte, p pointer, wiretag uint64, mi *MessageInfo,
return b, nil
}
func consumeGroupSliceInfo(b []byte, p pointer, num wire.Number, wtyp wire.Type, mi *MessageInfo, opts unmarshalOptions) (int, error) {
if wtyp != wire.StartGroupType {
return 0, errUnknown
}
m := reflect.New(mi.GoType.Elem()).Interface()
mp := pointerOfIface(m)
n, err := mi.unmarshalPointer(b, mp, num, opts)
if err != nil {
return 0, err
}
p.AppendPointerSlice(mp)
return n, nil
}
func sizeGroupSliceIface(ival interface{}, tagsize int, opts marshalOptions) int {
p := pointerOfIface(ival)
return sizeGroupSlice(p, reflect.TypeOf(ival).Elem().Elem(), tagsize, opts)
@ -423,10 +586,17 @@ func appendGroupSliceIface(b []byte, ival interface{}, wiretag uint64, opts mars
return appendGroupSlice(b, p, wiretag, reflect.TypeOf(ival).Elem().Elem(), opts)
}
func consumeGroupSliceIface(b []byte, ival interface{}, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
p := pointerOfIface(ival)
n, err := consumeGroupSlice(b, p, num, wtyp, reflect.TypeOf(ival).Elem().Elem(), opts)
return ival, n, err
}
var coderGroupSliceIface = ifaceCoderFuncs{
size: sizeGroupSliceIface,
marshal: appendGroupSliceIface,
isInit: isInitMessageSliceIface,
size: sizeGroupSliceIface,
marshal: appendGroupSliceIface,
unmarshal: consumeGroupSliceIface,
isInit: isInitMessageSliceIface,
}
// Enums
@ -443,9 +613,23 @@ func appendEnumIface(b []byte, ival interface{}, wiretag uint64, _ marshalOption
return b, nil
}
func consumeEnumIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
if wtyp != wire.VarintType {
return nil, 0, errUnknown
}
v, n := wire.ConsumeVarint(b)
if n < 0 {
return nil, 0, wire.ParseError(n)
}
rv := reflect.New(reflect.TypeOf(ival)).Elem()
rv.SetInt(int64(v))
return rv.Interface(), n, nil
}
var coderEnumIface = ifaceCoderFuncs{
size: sizeEnumIface,
marshal: appendEnumIface,
size: sizeEnumIface,
marshal: appendEnumIface,
unmarshal: consumeEnumIface,
}
func sizeEnumSliceIface(ival interface{}, tagsize int, opts marshalOptions) (size int) {
@ -471,9 +655,47 @@ func appendEnumSliceReflect(b []byte, s reflect.Value, wiretag uint64, opts mars
return b, nil
}
func consumeEnumSliceIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
n, err := consumeEnumSliceReflect(b, reflect.ValueOf(ival), wtyp, opts)
return ival, n, err
}
func consumeEnumSliceReflect(b []byte, s reflect.Value, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
s = s.Elem() // *[]E -> []E
if wtyp == wire.BytesType {
b, n = wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
for len(b) > 0 {
v, n := wire.ConsumeVarint(b)
if n < 0 {
return 0, wire.ParseError(n)
}
rv := reflect.New(s.Type().Elem()).Elem()
rv.SetInt(int64(v))
s.Set(reflect.Append(s, rv))
b = b[n:]
}
return n, nil
}
if wtyp != wire.VarintType {
return 0, errUnknown
}
v, n := wire.ConsumeVarint(b)
if n < 0 {
return 0, wire.ParseError(n)
}
rv := reflect.New(s.Type().Elem()).Elem()
rv.SetInt(int64(v))
s.Set(reflect.Append(s, rv))
return n, nil
}
var coderEnumSliceIface = ifaceCoderFuncs{
size: sizeEnumSliceIface,
marshal: appendEnumSliceIface,
size: sizeEnumSliceIface,
marshal: appendEnumSliceIface,
unmarshal: consumeEnumSliceIface,
}
// Strings with UTF8 validation.
@ -488,9 +710,25 @@ func appendStringValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOpti
return b, nil
}
func consumeStringValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
v, n := wire.ConsumeString(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return 0, errInvalidUTF8{}
}
*p.String() = v
return n, nil
}
var coderStringValidateUTF8 = pointerCoderFuncs{
size: sizeString,
marshal: appendStringValidateUTF8,
size: sizeString,
marshal: appendStringValidateUTF8,
unmarshal: consumeStringValidateUTF8,
}
func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
@ -507,8 +745,9 @@ func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marsh
}
var coderStringNoZeroValidateUTF8 = pointerCoderFuncs{
size: sizeStringNoZero,
marshal: appendStringNoZeroValidateUTF8,
size: sizeStringNoZero,
marshal: appendStringNoZeroValidateUTF8,
unmarshal: consumeStringValidateUTF8,
}
func sizeStringSliceValidateUTF8(p pointer, tagsize int, _ marshalOptions) (size int) {
@ -526,15 +765,32 @@ func appendStringSliceValidateUTF8(b []byte, p pointer, wiretag uint64, _ marsha
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
err = errInvalidUTF8{}
return b, errInvalidUTF8{}
}
}
return b, err
}
func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
sp := p.StringSlice()
v, n := wire.ConsumeString(b)
if n < 0 {
return 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return 0, errInvalidUTF8{}
}
*sp = append(*sp, v)
return n, nil
}
var coderStringSliceValidateUTF8 = pointerCoderFuncs{
size: sizeStringSliceValidateUTF8,
marshal: appendStringSliceValidateUTF8,
size: sizeStringSliceValidateUTF8,
marshal: appendStringSliceValidateUTF8,
unmarshal: consumeStringSliceValidateUTF8,
}
func sizeStringIfaceValidateUTF8(ival interface{}, tagsize int, _ marshalOptions) int {
@ -552,9 +808,24 @@ func appendStringIfaceValidateUTF8(b []byte, ival interface{}, wiretag uint64, _
return b, nil
}
func consumeStringIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
if wtyp != wire.BytesType {
return nil, 0, errUnknown
}
v, n := wire.ConsumeString(b)
if n < 0 {
return nil, 0, wire.ParseError(n)
}
if !utf8.ValidString(v) {
return nil, 0, errInvalidUTF8{}
}
return v, n, nil
}
var coderStringIfaceValidateUTF8 = ifaceCoderFuncs{
size: sizeStringIfaceValidateUTF8,
marshal: appendStringIfaceValidateUTF8,
size: sizeStringIfaceValidateUTF8,
marshal: appendStringIfaceValidateUTF8,
unmarshal: consumeStringIfaceValidateUTF8,
}
func asMessage(v reflect.Value) pref.ProtoMessage {

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,17 @@ import (
var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
type mapInfo struct {
goType reflect.Type
keyWiretag uint64
valWiretag uint64
keyFuncs ifaceCoderFuncs
valFuncs ifaceCoderFuncs
keyZero interface{}
valZero interface{}
newVal func() interface{}
}
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
// TODO: Consider generating specialized map coders.
keyField := fd.MapKey()
@ -25,6 +36,22 @@ func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointer
keyFuncs := encoderFuncsForValue(keyField, ft.Key())
valFuncs := encoderFuncsForValue(valField, ft.Elem())
mapi := &mapInfo{
goType: ft,
keyWiretag: keyWiretag,
valWiretag: valWiretag,
keyFuncs: keyFuncs,
valFuncs: valFuncs,
keyZero: reflect.Zero(ft.Key()).Interface(),
valZero: reflect.Zero(ft.Elem()).Interface(),
}
switch valField.Kind() {
case pref.GroupKind, pref.MessageKind:
mapi.newVal = func() interface{} {
return reflect.New(ft.Elem().Elem()).Interface()
}
}
funcs = pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
@ -32,6 +59,9 @@ func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointer
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
return consumeMap(b, p, wtyp, mapi, opts)
},
}
if valFuncs.isInit != nil {
funcs.isInit = func(p pointer) error {
@ -46,6 +76,64 @@ const (
mapValTagSize = 1 // field 2, tag size 2.
)
func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
mp := p.AsValueOf(mapi.goType)
if mp.Elem().IsNil() {
mp.Elem().Set(reflect.MakeMap(mapi.goType))
}
m := mp.Elem()
if wtyp != wire.BytesType {
return 0, errUnknown
}
b, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
var (
key = mapi.keyZero
val = mapi.valZero
)
if mapi.newVal != nil {
val = mapi.newVal()
}
for len(b) > 0 {
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
return 0, wire.ParseError(n)
}
b = b[n:]
err := errUnknown
switch num {
case 1:
var v interface{}
v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
if err != nil {
break
}
key = v
case 2:
var v interface{}
v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
if err != nil {
break
}
val = v
}
if err == errUnknown {
n = wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return 0, wire.ParseError(n)
}
} else if err != nil {
return 0, err
}
b = b[n:]
}
m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
return n, nil
}
func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
m := p.AsValueOf(goType).Elem()
n := 0

View File

@ -18,6 +18,8 @@ import (
// possible.
type coderMessageInfo struct {
orderedCoderFields []*coderFieldInfo
denseCoderFields []*coderFieldInfo
coderFields map[wire.Number]*coderFieldInfo
sizecacheOffset offset
unknownOffset offset
extensionOffset offset
@ -39,13 +41,14 @@ func (mi *MessageInfo) makeMethods(t reflect.Type, si structInfo) {
mi.unknownOffset = si.unknownOffset
mi.extensionOffset = si.extensionOffset
mi.coderFields = make(map[wire.Number]*coderFieldInfo)
for i := 0; i < mi.PBType.Descriptor().Fields().Len(); i++ {
fd := mi.PBType.Descriptor().Fields().Get(i)
if fd.ContainingOneof() != nil {
continue
}
fs := si.fieldsByNumber[fd.Number()]
if fd.ContainingOneof() != nil {
fs = si.oneofsByName[fd.ContainingOneof().Name()]
}
ft := fs.Type
var wiretag uint64
if !fd.IsPacked() {
@ -53,37 +56,51 @@ func (mi *MessageInfo) makeMethods(t reflect.Type, si structInfo) {
} else {
wiretag = wire.EncodeTag(fd.Number(), wire.BytesType)
}
mi.orderedCoderFields = append(mi.orderedCoderFields, &coderFieldInfo{
var funcs pointerCoderFuncs
if fd.ContainingOneof() != nil {
funcs = makeOneofFieldCoder(si, fd)
} else {
funcs = fieldCoder(fd, ft)
}
cf := &coderFieldInfo{
num: fd.Number(),
offset: offsetOf(fs, mi.Exporter),
wiretag: wiretag,
tagsize: wire.SizeVarint(wiretag),
funcs: fieldCoder(fd, ft),
funcs: funcs,
isPointer: (fd.Cardinality() == pref.Repeated ||
fd.Kind() == pref.MessageKind ||
fd.Kind() == pref.GroupKind ||
fd.Syntax() != pref.Proto3),
isRequired: fd.Cardinality() == pref.Required,
})
}
for i := 0; i < mi.PBType.Descriptor().Oneofs().Len(); i++ {
od := mi.PBType.Descriptor().Oneofs().Get(i)
fs := si.oneofsByName[od.Name()]
mi.orderedCoderFields = append(mi.orderedCoderFields, &coderFieldInfo{
num: od.Fields().Get(0).Number(),
offset: offsetOf(fs, mi.Exporter),
funcs: makeOneofFieldCoder(fs, od, si.fieldsByNumber, si.oneofWrappersByNumber),
isPointer: true,
})
}
mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
mi.coderFields[cf.num] = cf
}
sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
})
var maxDense pref.FieldNumber
for _, cf := range mi.orderedCoderFields {
if cf.num >= 16 && cf.num >= 2*maxDense {
break
}
maxDense = cf.num
}
mi.denseCoderFields = make([]*coderFieldInfo, maxDense+1)
for _, cf := range mi.orderedCoderFields {
if int(cf.num) > len(mi.denseCoderFields) {
break
}
mi.denseCoderFields[cf.num] = cf
}
mi.needsInitCheck = needsInitCheck(mi.PBType)
mi.methods = piface.Methods{
Flags: piface.MethodFlagDeterministicMarshal,
MarshalAppend: mi.marshalAppend,
Unmarshal: mi.unmarshal,
Size: mi.size,
IsInitialized: mi.isInitialized,
}

View File

@ -7,6 +7,8 @@
package impl
import (
"reflect"
"google.golang.org/protobuf/internal/encoding/wire"
)
@ -22,9 +24,22 @@ func appendEnum(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byt
return b, nil
}
func consumeEnum(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
if wtyp != wire.VarintType {
return 0, errUnknown
}
v, n := wire.ConsumeVarint(b)
if n < 0 {
return 0, wire.ParseError(n)
}
p.v.Elem().SetInt(int64(v))
return n, nil
}
var coderEnum = pointerCoderFuncs{
size: sizeEnum,
marshal: appendEnum,
size: sizeEnum,
marshal: appendEnum,
unmarshal: consumeEnum,
}
func sizeEnumNoZero(p pointer, tagsize int, opts marshalOptions) (size int) {
@ -42,8 +57,9 @@ func appendEnumNoZero(b []byte, p pointer, wiretag uint64, opts marshalOptions)
}
var coderEnumNoZero = pointerCoderFuncs{
size: sizeEnumNoZero,
marshal: appendEnumNoZero,
size: sizeEnumNoZero,
marshal: appendEnumNoZero,
unmarshal: consumeEnum,
}
func sizeEnumPtr(p pointer, tagsize int, opts marshalOptions) (size int) {
@ -54,9 +70,20 @@ func appendEnumPtr(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]
return appendEnum(b, pointer{p.v.Elem()}, wiretag, opts)
}
func consumeEnumPtr(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (n int, err error) {
if wtyp != wire.VarintType {
return 0, errUnknown
}
if p.v.Elem().IsNil() {
p.v.Elem().Set(reflect.New(p.v.Elem().Type().Elem()))
}
return consumeEnum(b, pointer{p.v.Elem()}, wtyp, opts)
}
var coderEnumPtr = pointerCoderFuncs{
size: sizeEnumPtr,
marshal: appendEnumPtr,
size: sizeEnumPtr,
marshal: appendEnumPtr,
unmarshal: consumeEnumPtr,
}
func sizeEnumSlice(p pointer, tagsize int, opts marshalOptions) (size int) {
@ -67,9 +94,14 @@ func appendEnumSlice(b []byte, p pointer, wiretag uint64, opts marshalOptions) (
return appendEnumSliceReflect(b, p.v.Elem(), wiretag, opts)
}
func consumeEnumSlice(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (n int, err error) {
return consumeEnumSliceReflect(b, p.v, wtyp, opts)
}
var coderEnumSlice = pointerCoderFuncs{
size: sizeEnumSlice,
marshal: appendEnumSlice,
size: sizeEnumSlice,
marshal: appendEnumSlice,
unmarshal: consumeEnumSlice,
}
func sizeEnumPackedSlice(p pointer, tagsize int, _ marshalOptions) (size int) {
@ -104,6 +136,7 @@ func appendEnumPackedSlice(b []byte, p pointer, wiretag uint64, opts marshalOpti
}
var coderEnumPackedSlice = pointerCoderFuncs{
size: sizeEnumPackedSlice,
marshal: appendEnumPackedSlice,
size: sizeEnumPackedSlice,
marshal: appendEnumPackedSlice,
unmarshal: consumeEnumSlice,
}

View File

@ -8,21 +8,24 @@ import (
"fmt"
"reflect"
"google.golang.org/protobuf/internal/encoding/wire"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
// pointerCoderFuncs is a set of pointer encoding functions.
type pointerCoderFuncs struct {
size func(p pointer, tagsize int, opts marshalOptions) int
marshal func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error)
isInit func(p pointer) error
size func(p pointer, tagsize int, opts marshalOptions) int
marshal func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error)
unmarshal func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error)
isInit func(p pointer) error
}
// ifaceCoderFuncs is a set of interface{} encoding functions.
type ifaceCoderFuncs struct {
size func(ival interface{}, tagsize int, opts marshalOptions) int
marshal func(b []byte, ival interface{}, wiretag uint64, opts marshalOptions) ([]byte, error)
isInit func(ival interface{}) error
size func(ival interface{}, tagsize int, opts marshalOptions) int
marshal func(b []byte, ival interface{}, wiretag uint64, opts marshalOptions) ([]byte, error)
unmarshal func(b []byte, ival interface{}, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error)
isInit func(ival interface{}) error
}
// fieldCoder returns pointer functions for a field, used for operating on
@ -574,7 +577,7 @@ func encoderFuncsForValue(fd pref.FieldDescriptor, ft reflect.Type) ifaceCoderFu
case pref.MessageKind:
return coderMessageIface
case pref.GroupKind:
return coderGroupIface
return makeGroupValueCoder(fd, ft)
}
}
panic(fmt.Errorf("invalid type: no encoder for %v %v %v/%v", fd.FullName(), fd.Cardinality(), fd.Kind(), ft))

162
internal/impl/decode.go Normal file
View File

@ -0,0 +1,162 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
piface "google.golang.org/protobuf/runtime/protoiface"
)
// unmarshalOptions is a more efficient representation of UnmarshalOptions.
//
// We don't preserve the AllowPartial flag, because fast-path (un)marshal
// operations always allow partial messages.
type unmarshalOptions struct {
flags unmarshalOptionFlags
resolver preg.ExtensionTypeResolver
}
type unmarshalOptionFlags uint8
const (
unmarshalDiscardUnknown unmarshalOptionFlags = 1 << iota
)
func newUnmarshalOptions(opts piface.UnmarshalOptions) unmarshalOptions {
o := unmarshalOptions{
resolver: opts.Resolver,
}
if opts.DiscardUnknown {
o.flags |= unmarshalDiscardUnknown
}
return o
}
func (o unmarshalOptions) Options() proto.UnmarshalOptions {
return proto.UnmarshalOptions{
AllowPartial: true,
DiscardUnknown: o.DiscardUnknown(),
Resolver: o.Resolver(),
}
}
func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&unmarshalDiscardUnknown != 0 }
func (o unmarshalOptions) Resolver() preg.ExtensionTypeResolver { return o.resolver }
// unmarshal is protoreflect.Methods.Unmarshal.
func (mi *MessageInfo) unmarshal(b []byte, m pref.ProtoMessage, opts piface.UnmarshalOptions) error {
_, err := mi.unmarshalPointer(b, pointerOfIface(m), 0, newUnmarshalOptions(opts))
return err
}
// errUnknown is returned during unmarshaling to indicate a parse error that
// should result in a field being placed in the unknown fields section (for example,
// when the wire type doesn't match) as opposed to the entire unmarshal operation
// failing (for example, when a field extends past the available input).
//
// This is a sentinel error which should never be visible to the user.
var errUnknown = errors.New("unknown")
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Number, opts unmarshalOptions) (int, error) {
mi.init()
var exts *map[int32]ExtensionField
start := len(b)
for len(b) > 0 {
// Parse the tag (field number and wire type).
// TODO: inline 1 and 2 byte variants?
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
return 0, wire.ParseError(n)
}
b = b[n:]
var f *coderFieldInfo
if int(num) < len(mi.denseCoderFields) {
f = mi.denseCoderFields[num]
} else {
f = mi.coderFields[num]
}
err := errUnknown
switch {
case f != nil:
if f.funcs.unmarshal == nil {
break
}
n, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts)
case num == groupTag && wtyp == wire.EndGroupType:
// End of group.
return start - len(b), nil
default:
// Possible extension.
if exts == nil && mi.extensionOffset.IsValid() {
exts = p.Apply(mi.extensionOffset).Extensions()
if *exts == nil {
*exts = make(map[int32]ExtensionField)
}
}
if exts == nil {
break
}
n, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
}
if err != nil {
if err != errUnknown {
return 0, err
}
n = wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return 0, wire.ParseError(n)
}
if mi.unknownOffset.IsValid() {
u := p.Apply(mi.unknownOffset).Bytes()
*u = wire.AppendTag(*u, num, wtyp)
*u = append(*u, b[:n]...)
}
}
b = b[n:]
}
if groupTag != 0 {
return 0, errors.New("missing end group marker")
}
return start, nil
}
func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (n int, err error) {
x := exts[int32(num)]
xt := x.GetType()
if xt == nil {
var err error
xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.FullName(), num)
if err != nil {
if err == preg.NotFound {
return 0, errUnknown
}
return 0, err
}
x.SetType(xt)
}
xi := mi.extensionFieldInfo(xt)
if xi.funcs.unmarshal == nil {
return 0, errUnknown
}
ival := x.GetValue()
if ival == nil && xi.unmarshalNeedsValue {
// Create a new message, list, or map value to fill in.
// For enums, create a prototype value to let the unmarshal func know the
// concrete type.
ival = xt.InterfaceOf(xt.New())
}
v, n, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
if err != nil {
return 0, err
}
x.SetEagerValue(v)
exts[int32(num)] = x
return n, nil
}

View File

@ -13,19 +13,19 @@ import (
piface "google.golang.org/protobuf/runtime/protoiface"
)
// marshalOptions is a more efficient representation of MarshalOptions.
//
// We don't preserve the AllowPartial flag, because fast-path (un)marshal
// operations always allow partial messages.
type marshalOptions uint
const (
marshalAllowPartial marshalOptions = 1 << iota
marshalDeterministic
marshalDeterministic marshalOptions = 1 << iota
marshalUseCachedSize
)
func newMarshalOptions(opts piface.MarshalOptions) marshalOptions {
var o marshalOptions
if opts.AllowPartial {
o |= marshalAllowPartial
}
if opts.Deterministic {
o |= marshalDeterministic
}
@ -37,13 +37,12 @@ func newMarshalOptions(opts piface.MarshalOptions) marshalOptions {
func (o marshalOptions) Options() proto.MarshalOptions {
return proto.MarshalOptions{
AllowPartial: o.AllowPartial(),
AllowPartial: true,
Deterministic: o.Deterministic(),
UseCachedSize: o.UseCachedSize(),
}
}
func (o marshalOptions) AllowPartial() bool { return o&marshalAllowPartial != 0 }
func (o marshalOptions) Deterministic() bool { return o&marshalDeterministic != 0 }
func (o marshalOptions) UseCachedSize() bool { return o&marshalUseCachedSize != 0 }

View File

@ -516,6 +516,7 @@ func (m *messageIfaceWrapper) XXX_Methods() *piface.Methods {
return &piface.Methods{
Flags: piface.MethodFlagDeterministicMarshal,
MarshalAppend: m.marshalAppend,
Unmarshal: m.unmarshal,
Size: m.size,
IsInitialized: m.isInitialized,
}
@ -526,9 +527,13 @@ func (m *messageIfaceWrapper) ProtoUnwrap() interface{} {
func (m *messageIfaceWrapper) marshalAppend(b []byte, _ pref.ProtoMessage, opts piface.MarshalOptions) ([]byte, error) {
return m.mi.marshalAppendPointer(b, m.p, newMarshalOptions(opts))
}
func (m *messageIfaceWrapper) unmarshal(b []byte, _ pref.ProtoMessage, opts piface.UnmarshalOptions) error {
_, err := m.mi.unmarshalPointer(b, m.p, 0, newUnmarshalOptions(opts))
return err
}
func (m *messageIfaceWrapper) size(msg pref.ProtoMessage) (size int) {
return m.mi.sizePointer(m.p, 0)
}
func (m *messageIfaceWrapper) isInitialized(msg pref.ProtoMessage) error {
func (m *messageIfaceWrapper) isInitialized(_ pref.ProtoMessage) error {
return m.mi.isInitializedPointer(m.p)
}

View File

@ -135,3 +135,14 @@ func (p pointer) PointerSlice() []pointer {
}
return s
}
// AppendPointerSlice appends v to p, which must be a []*T.
func (p pointer) AppendPointerSlice(v pointer) {
sp := p.v.Elem()
sp.Set(reflect.Append(sp, v.v))
}
// SetPointer sets *p to v.
func (p pointer) SetPointer(v pointer) {
p.v.Elem().Set(v.v)
}

View File

@ -115,3 +115,13 @@ func (p pointer) PointerSlice() []pointer {
// message type. We load it as []pointer.
return *(*[]pointer)(p.p)
}
// AppendPointerSlice appends v to p, which must be a []*T.
func (p pointer) AppendPointerSlice(v pointer) {
*(*[]pointer)(p.p) = append(*(*[]pointer)(p.p), v)
}
// SetPointer sets *p to v.
func (p pointer) SetPointer(v pointer) {
*(*unsafe.Pointer)(p.p) = (unsafe.Pointer)(v.p)
}