internal/impl: faster oneof marshaling

Change size, marshal, and isinit operations on oneofs to look up the
currently-set oneof type in a map rather than testing for each possible
oneof field in turn.

Significantly improves oneof encoding speed for oneofs with a
substantial number of fields:

  go test ./proto -bench=./oneof.*string.*test.TestAll -benchmem -count=8 -cpu=1

  name                                        old time/op    new time/op    delta
  Encode/oneof_(string)_(*test.TestAllTypes)     911ns ± 1%     397ns ± 3%  -56.45%  (p=0.000 n=8+7)
  Decode/oneof_(string)_(*test.TestAllTypes)     899ns ± 1%     922ns ± 1%   +2.49%  (p=0.001 n=7+7)

Fixes golang/protobuf#950

Change-Id: I9393a87975ce09011d885a8af4a63a639ea8452f
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/210281
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2019-12-05 16:36:28 -08:00
parent 79571e90e2
commit ce413af0b3
3 changed files with 76 additions and 50 deletions

View File

@ -20,40 +20,39 @@ type errInvalidUTF8 struct{}
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
func (errInvalidUTF8) InvalidUTF8() bool { return true }
func makeOneofFieldCoder(fd pref.FieldDescriptor, si structInfo) pointerCoderFuncs {
ot := si.oneofWrappersByNumber[fd.Number()]
funcs := fieldCoder(fd, ot.Field(0).Type)
fs := si.oneofsByName[fd.ContainingOneof().Name()]
ft := fs.Type
wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
tagsize := wire.SizeVarint(wiretag)
getInfo := func(p pointer) (pointer, bool) {
v := p.AsValueOf(ft).Elem()
if v.IsNil() {
return pointer{}, false
}
v = v.Elem() // interface -> *struct
if v.IsNil() || v.Elem().Type() != ot {
return pointer{}, false
}
return pointerOfValue(v).Apply(zeroOffset), true
// initOneofFieldCoders initializes the fast-path functions for the fields in a oneof.
//
// For size, marshal, and isInit operations, functions are set only on the first field
// in the oneof. The functions are called when the oneof is non-nil, and will dispatch
// to the appropriate field-specific function as necessary.
//
// The unmarshal function is set on each field individually as usual.
func (mi *MessageInfo) initOneofFieldCoders(od pref.OneofDescriptor, si structInfo) {
type oneofFieldInfo struct {
wiretag uint64 // field tag (number + wire type)
tagsize int // size of the varint-encoded tag
funcs pointerCoderFuncs
}
pcf := pointerCoderFuncs{
size: func(p pointer, _ int, opts marshalOptions) int {
v, ok := getInfo(p)
if !ok {
return 0
}
return funcs.size(v, tagsize, opts)
},
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
v, ok := getInfo(p)
if !ok {
return b, nil
}
return funcs.marshal(b, v, wiretag, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
fs := si.oneofsByName[od.Name()]
ft := fs.Type
oneofFields := make(map[reflect.Type]*oneofFieldInfo)
needIsInit := false
fields := od.Fields()
for i, lim := 0, fields.Len(); i < lim; i++ {
fd := od.Fields().Get(i)
num := fd.Number()
cf := mi.coderFields[num]
ot := si.oneofWrappersByNumber[num]
funcs := fieldCoder(fd, ot.Field(0).Type)
oneofFields[ot] = &oneofFieldInfo{
wiretag: cf.wiretag,
tagsize: cf.tagsize,
funcs: funcs,
}
if funcs.isInit != nil {
needIsInit = true
}
cf.funcs.unmarshal = func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
var vw reflect.Value // pointer to wrapper type
vi := p.AsValueOf(ft).Elem() // oneof field value of interface kind
if !vi.IsNil() && !vi.Elem().IsNil() && vi.Elem().Elem().Type() == ot {
@ -67,18 +66,43 @@ func makeOneofFieldCoder(fd pref.FieldDescriptor, si structInfo) pointerCoderFun
}
vi.Set(vw)
return n, nil
},
}
if funcs.isInit != nil {
pcf.isInit = func(p pointer) error {
v, ok := getInfo(p)
if !ok {
return nil
}
return funcs.isInit(v)
}
}
return pcf
getInfo := func(p pointer) (pointer, *oneofFieldInfo) {
v := p.AsValueOf(ft).Elem()
if v.IsNil() {
return pointer{}, nil
}
v = v.Elem() // interface -> *struct
if v.IsNil() {
return pointer{}, nil
}
return pointerOfValue(v).Apply(zeroOffset), oneofFields[v.Elem().Type()]
}
first := mi.coderFields[od.Fields().Get(0).Number()]
first.funcs.size = func(p pointer, tagsize int, opts marshalOptions) int {
p, info := getInfo(p)
if info == nil || info.funcs.size == nil {
return 0
}
return info.funcs.size(p, info.tagsize, opts)
}
first.funcs.marshal = func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
p, info := getInfo(p)
if info == nil || info.funcs.marshal == nil {
return b, nil
}
return info.funcs.marshal(b, p, info.wiretag, opts)
}
if needIsInit {
first.funcs.isInit = func(p pointer) error {
p, info := getInfo(p)
if info == nil || info.funcs.isInit == nil {
return nil
}
return info.funcs.isInit(p)
}
}
}
func makeWeakMessageFieldCoder(fd pref.FieldDescriptor) pointerCoderFuncs {

View File

@ -68,7 +68,6 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
switch {
case fd.ContainingOneof() != nil:
fieldOffset = offsetOf(fs, mi.Exporter)
funcs = makeOneofFieldCoder(fd, si)
case fd.IsWeak():
fieldOffset = si.weakOffset
funcs = makeWeakMessageFieldCoder(fd)
@ -91,6 +90,9 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
mi.coderFields[cf.num] = cf
}
for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
mi.initOneofFieldCoders(oneofs.Get(i), si)
}
if messageset.IsMessageSet(mi.Desc) {
if !mi.extensionOffset.IsValid() {
panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))

View File

@ -82,11 +82,11 @@ func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int
size += mi.sizeExtensions(e, opts)
}
for _, f := range mi.orderedCoderFields {
fptr := p.Apply(f.offset)
if f.isPointer && fptr.Elem().IsNil() {
if f.funcs.size == nil {
continue
}
if f.funcs.size == nil {
fptr := p.Apply(f.offset)
if f.isPointer && fptr.Elem().IsNil() {
continue
}
size += f.funcs.size(fptr, f.tagsize, opts)
@ -131,11 +131,11 @@ func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOpt
}
}
for _, f := range mi.orderedCoderFields {
fptr := p.Apply(f.offset)
if f.isPointer && fptr.Elem().IsNil() {
if f.funcs.marshal == nil {
continue
}
if f.funcs.marshal == nil {
fptr := p.Apply(f.offset)
if f.isPointer && fptr.Elem().IsNil() {
continue
}
b, err = f.funcs.marshal(b, fptr, f.wiretag, opts)