mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-01-29 09:32:38 +00:00
c600d6c086
Add a fast check for required fields to the fast path unmarshal. This is best-effort and will fail to detect some initialized messages: Messages with more than 64 required fields, messages split across multiple tags, possibly other cases. In the cases where it works (which is most of them in practice), this permits us to skip the IsInitialized check. Change-Id: I6b70953a333033a5e64fb7ca37a59786cb0f75a0 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215878 Reviewed-by: Joe Tsai <joetsai@google.com>
331 lines
8.6 KiB
Go
331 lines
8.6 KiB
Go
// Copyright 2019 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package impl
|
|
|
|
import (
|
|
"errors"
|
|
"reflect"
|
|
"sort"
|
|
|
|
"google.golang.org/protobuf/internal/encoding/wire"
|
|
pref "google.golang.org/protobuf/reflect/protoreflect"
|
|
)
|
|
|
|
type mapInfo struct {
|
|
goType reflect.Type
|
|
keyWiretag uint64
|
|
valWiretag uint64
|
|
keyFuncs valueCoderFuncs
|
|
valFuncs valueCoderFuncs
|
|
keyZero pref.Value
|
|
keyKind pref.Kind
|
|
valMessageInfo *MessageInfo
|
|
conv *mapConverter
|
|
}
|
|
|
|
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
|
|
// TODO: Consider generating specialized map coders.
|
|
keyField := fd.MapKey()
|
|
valField := fd.MapValue()
|
|
keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
|
|
valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
|
|
keyFuncs := encoderFuncsForValue(keyField)
|
|
valFuncs := encoderFuncsForValue(valField)
|
|
conv := newMapConverter(ft, fd)
|
|
|
|
mapi := &mapInfo{
|
|
goType: ft,
|
|
keyWiretag: keyWiretag,
|
|
valWiretag: valWiretag,
|
|
keyFuncs: keyFuncs,
|
|
valFuncs: valFuncs,
|
|
keyZero: keyField.Default(),
|
|
keyKind: keyField.Kind(),
|
|
conv: conv,
|
|
}
|
|
if valField.Kind() == pref.MessageKind {
|
|
mapi.valMessageInfo = getMessageInfo(ft.Elem())
|
|
}
|
|
|
|
funcs = pointerCoderFuncs{
|
|
size: func(p pointer, tagsize int, opts marshalOptions) int {
|
|
return sizeMap(p.AsValueOf(ft).Elem(), tagsize, mapi, opts)
|
|
},
|
|
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
|
|
return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
|
|
},
|
|
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
|
|
mp := p.AsValueOf(ft)
|
|
if mp.Elem().IsNil() {
|
|
mp.Elem().Set(reflect.MakeMap(mapi.goType))
|
|
}
|
|
if mapi.valMessageInfo == nil {
|
|
return consumeMap(b, mp.Elem(), wtyp, mapi, opts)
|
|
} else {
|
|
return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, opts)
|
|
}
|
|
},
|
|
}
|
|
if valFuncs.isInit != nil {
|
|
funcs.isInit = func(p pointer) error {
|
|
return isInitMap(p.AsValueOf(ft).Elem(), mapi)
|
|
}
|
|
}
|
|
return funcs
|
|
}
|
|
|
|
// Encoded sizes, in bytes, of the tags that introduce the key (field 1) and
// the value (field 2) inside a map entry. Both field numbers are small enough
// that their tags fit in a single varint byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
|
|
|
|
func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions) int {
|
|
if mapv.Len() == 0 {
|
|
return 0
|
|
}
|
|
n := 0
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
|
|
keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
var valSize int
|
|
value := mapi.conv.valConv.PBValueOf(iter.Value())
|
|
if mapi.valMessageInfo == nil {
|
|
valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
|
|
} else {
|
|
p := pointerOfValue(iter.Value())
|
|
valSize += mapValTagSize
|
|
valSize += wire.SizeBytes(mapi.valMessageInfo.sizePointer(p, opts))
|
|
}
|
|
n += tagsize + wire.SizeBytes(keySize+valSize)
|
|
}
|
|
return n
|
|
}
|
|
|
|
func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
|
|
if wtyp != wire.BytesType {
|
|
return out, errUnknown
|
|
}
|
|
b, n := wire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, wire.ParseError(n)
|
|
}
|
|
var (
|
|
key = mapi.keyZero
|
|
val = mapi.conv.valConv.New()
|
|
)
|
|
for len(b) > 0 {
|
|
num, wtyp, n := wire.ConsumeTag(b)
|
|
if n < 0 {
|
|
return out, wire.ParseError(n)
|
|
}
|
|
if num > wire.MaxValidNumber {
|
|
return out, errors.New("invalid field number")
|
|
}
|
|
b = b[n:]
|
|
err := errUnknown
|
|
switch num {
|
|
case 1:
|
|
var v pref.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
key = v
|
|
n = o.n
|
|
case 2:
|
|
var v pref.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
val = v
|
|
n = o.n
|
|
}
|
|
if err == errUnknown {
|
|
n = wire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return out, wire.ParseError(n)
|
|
}
|
|
} else if err != nil {
|
|
return out, err
|
|
}
|
|
b = b[n:]
|
|
}
|
|
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
|
|
out.n = n
|
|
return out, nil
|
|
}
|
|
|
|
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
|
|
if wtyp != wire.BytesType {
|
|
return out, errUnknown
|
|
}
|
|
b, n := wire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, wire.ParseError(n)
|
|
}
|
|
var (
|
|
key = mapi.keyZero
|
|
val = reflect.New(mapi.valMessageInfo.GoReflectType.Elem())
|
|
)
|
|
for len(b) > 0 {
|
|
num, wtyp, n := wire.ConsumeTag(b)
|
|
if n < 0 {
|
|
return out, wire.ParseError(n)
|
|
}
|
|
if num > wire.MaxValidNumber {
|
|
return out, errors.New("invalid field number")
|
|
}
|
|
b = b[n:]
|
|
err := errUnknown
|
|
switch num {
|
|
case 1:
|
|
var v pref.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
key = v
|
|
n = o.n
|
|
case 2:
|
|
if wtyp != wire.BytesType {
|
|
break
|
|
}
|
|
var v []byte
|
|
v, n = wire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, wire.ParseError(n)
|
|
}
|
|
var o unmarshalOutput
|
|
o, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
|
|
if o.initialized {
|
|
// Consider this map item initialized so long as we see
|
|
// an initialized value.
|
|
out.initialized = true
|
|
}
|
|
}
|
|
if err == errUnknown {
|
|
n = wire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return out, wire.ParseError(n)
|
|
}
|
|
} else if err != nil {
|
|
return out, err
|
|
}
|
|
b = b[n:]
|
|
}
|
|
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
|
|
out.n = n
|
|
return out, nil
|
|
}
|
|
|
|
func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
|
|
if mapi.valMessageInfo == nil {
|
|
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
|
|
val := mapi.conv.valConv.PBValueOf(valrv)
|
|
size := 0
|
|
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
size += mapi.valFuncs.size(val, mapValTagSize, opts)
|
|
b = wire.AppendVarint(b, uint64(size))
|
|
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
|
|
} else {
|
|
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
|
|
val := pointerOfValue(valrv)
|
|
valSize := mapi.valMessageInfo.sizePointer(val, opts)
|
|
size := 0
|
|
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
size += mapValTagSize + wire.SizeBytes(valSize)
|
|
b = wire.AppendVarint(b, uint64(size))
|
|
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b = wire.AppendVarint(b, mapi.valWiretag)
|
|
b = wire.AppendVarint(b, uint64(valSize))
|
|
return mapi.valMessageInfo.marshalAppendPointer(b, val, opts)
|
|
}
|
|
}
|
|
|
|
func appendMap(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
|
|
if mapv.Len() == 0 {
|
|
return b, nil
|
|
}
|
|
if opts.Deterministic() {
|
|
return appendMapDeterministic(b, mapv, wiretag, mapi, opts)
|
|
}
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
var err error
|
|
b = wire.AppendVarint(b, wiretag)
|
|
b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, opts)
|
|
if err != nil {
|
|
return b, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func appendMapDeterministic(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
|
|
keys := mapv.MapKeys()
|
|
sort.Slice(keys, func(i, j int) bool {
|
|
switch keys[i].Kind() {
|
|
case reflect.Bool:
|
|
return !keys[i].Bool() && keys[j].Bool()
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return keys[i].Int() < keys[j].Int()
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
|
return keys[i].Uint() < keys[j].Uint()
|
|
case reflect.Float32, reflect.Float64:
|
|
return keys[i].Float() < keys[j].Float()
|
|
case reflect.String:
|
|
return keys[i].String() < keys[j].String()
|
|
default:
|
|
panic("invalid kind: " + keys[i].Kind().String())
|
|
}
|
|
})
|
|
for _, key := range keys {
|
|
var err error
|
|
b = wire.AppendVarint(b, wiretag)
|
|
b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, opts)
|
|
if err != nil {
|
|
return b, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func isInitMap(mapv reflect.Value, mapi *mapInfo) error {
|
|
if mi := mapi.valMessageInfo; mi != nil {
|
|
mi.init()
|
|
if !mi.needsInitCheck {
|
|
return nil
|
|
}
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
val := pointerOfValue(iter.Value())
|
|
if err := mi.isInitializedPointer(val); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
val := mapi.conv.valConv.PBValueOf(iter.Value())
|
|
if err := mapi.valFuncs.isInit(val); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|