mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-01-17 01:12:51 +00:00
94bb78c93b
There currently is no risk of producing invalid wire format, but that will change with subsequent changes regarding lazy decoding. We have been running this change in production for about a month, without ever triggering the check (until lazy decoding is involved). related to golang/protobuf#1609 Change-Id: I3c5c956aee2fa81f99dea03ed2a977a1547081fc Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/579595 Auto-Submit: Michael Stapelberg <stapelberg@google.com> Reviewed-by: Lasse Folger <lassefolger@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
400 lines
11 KiB
Go
400 lines
11 KiB
Go
// Copyright 2019 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package impl
|
|
|
|
import (
|
|
"reflect"
|
|
"sort"
|
|
|
|
"google.golang.org/protobuf/encoding/protowire"
|
|
"google.golang.org/protobuf/internal/errors"
|
|
"google.golang.org/protobuf/internal/genid"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
)
|
|
|
|
type mapInfo struct {
|
|
goType reflect.Type
|
|
keyWiretag uint64
|
|
valWiretag uint64
|
|
keyFuncs valueCoderFuncs
|
|
valFuncs valueCoderFuncs
|
|
keyZero protoreflect.Value
|
|
keyKind protoreflect.Kind
|
|
conv *mapConverter
|
|
}
|
|
|
|
func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
|
|
// TODO: Consider generating specialized map coders.
|
|
keyField := fd.MapKey()
|
|
valField := fd.MapValue()
|
|
keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
|
|
valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
|
|
keyFuncs := encoderFuncsForValue(keyField)
|
|
valFuncs := encoderFuncsForValue(valField)
|
|
conv := newMapConverter(ft, fd)
|
|
|
|
mapi := &mapInfo{
|
|
goType: ft,
|
|
keyWiretag: keyWiretag,
|
|
valWiretag: valWiretag,
|
|
keyFuncs: keyFuncs,
|
|
valFuncs: valFuncs,
|
|
keyZero: keyField.Default(),
|
|
keyKind: keyField.Kind(),
|
|
conv: conv,
|
|
}
|
|
if valField.Kind() == protoreflect.MessageKind {
|
|
valueMessage = getMessageInfo(ft.Elem())
|
|
}
|
|
|
|
funcs = pointerCoderFuncs{
|
|
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
|
|
return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
|
|
},
|
|
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
|
|
return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
|
|
},
|
|
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
|
|
mp := p.AsValueOf(ft)
|
|
if mp.Elem().IsNil() {
|
|
mp.Elem().Set(reflect.MakeMap(mapi.goType))
|
|
}
|
|
if f.mi == nil {
|
|
return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
|
|
} else {
|
|
return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
|
|
}
|
|
},
|
|
}
|
|
switch valField.Kind() {
|
|
case protoreflect.MessageKind:
|
|
funcs.merge = mergeMapOfMessage
|
|
case protoreflect.BytesKind:
|
|
funcs.merge = mergeMapOfBytes
|
|
default:
|
|
funcs.merge = mergeMap
|
|
}
|
|
if valFuncs.isInit != nil {
|
|
funcs.isInit = func(p pointer, f *coderFieldInfo) error {
|
|
return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
|
|
}
|
|
}
|
|
return valueMessage, funcs
|
|
}
|
|
|
|
// Encoded sizes of the MapEntry key and value tags. A tag is the varint
// (field_number<<3 | wire_type); for field numbers up to 15 it is below 0x80
// and therefore always occupies exactly one byte, whatever the wire type.
const (
	mapKeyTagSize = 1 // field 1: tag encodes in one byte.
	mapValTagSize = 1 // field 2: tag encodes in one byte.
)
|
|
|
|
func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
|
|
if mapv.Len() == 0 {
|
|
return 0
|
|
}
|
|
n := 0
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
|
|
keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
var valSize int
|
|
value := mapi.conv.valConv.PBValueOf(iter.Value())
|
|
if f.mi == nil {
|
|
valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
|
|
} else {
|
|
p := pointerOfValue(iter.Value())
|
|
valSize += mapValTagSize
|
|
valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
|
|
}
|
|
n += f.tagsize + protowire.SizeBytes(keySize+valSize)
|
|
}
|
|
return n
|
|
}
|
|
|
|
func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
|
|
if wtyp != protowire.BytesType {
|
|
return out, errUnknown
|
|
}
|
|
b, n := protowire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
var (
|
|
key = mapi.keyZero
|
|
val = mapi.conv.valConv.New()
|
|
)
|
|
for len(b) > 0 {
|
|
num, wtyp, n := protowire.ConsumeTag(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
if num > protowire.MaxValidNumber {
|
|
return out, errDecode
|
|
}
|
|
b = b[n:]
|
|
err := errUnknown
|
|
switch num {
|
|
case genid.MapEntry_Key_field_number:
|
|
var v protoreflect.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
key = v
|
|
n = o.n
|
|
case genid.MapEntry_Value_field_number:
|
|
var v protoreflect.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
val = v
|
|
n = o.n
|
|
}
|
|
if err == errUnknown {
|
|
n = protowire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
} else if err != nil {
|
|
return out, err
|
|
}
|
|
b = b[n:]
|
|
}
|
|
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
|
|
out.n = n
|
|
return out, nil
|
|
}
|
|
|
|
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
|
|
if wtyp != protowire.BytesType {
|
|
return out, errUnknown
|
|
}
|
|
b, n := protowire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
var (
|
|
key = mapi.keyZero
|
|
val = reflect.New(f.mi.GoReflectType.Elem())
|
|
)
|
|
for len(b) > 0 {
|
|
num, wtyp, n := protowire.ConsumeTag(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
if num > protowire.MaxValidNumber {
|
|
return out, errDecode
|
|
}
|
|
b = b[n:]
|
|
err := errUnknown
|
|
switch num {
|
|
case 1:
|
|
var v protoreflect.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
key = v
|
|
n = o.n
|
|
case 2:
|
|
if wtyp != protowire.BytesType {
|
|
break
|
|
}
|
|
var v []byte
|
|
v, n = protowire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
var o unmarshalOutput
|
|
o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
|
|
if o.initialized {
|
|
// Consider this map item initialized so long as we see
|
|
// an initialized value.
|
|
out.initialized = true
|
|
}
|
|
}
|
|
if err == errUnknown {
|
|
n = protowire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
} else if err != nil {
|
|
return out, err
|
|
}
|
|
b = b[n:]
|
|
}
|
|
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
|
|
out.n = n
|
|
return out, nil
|
|
}
|
|
|
|
func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
|
|
if f.mi == nil {
|
|
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
|
|
val := mapi.conv.valConv.PBValueOf(valrv)
|
|
size := 0
|
|
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
size += mapi.valFuncs.size(val, mapValTagSize, opts)
|
|
b = protowire.AppendVarint(b, uint64(size))
|
|
before := len(b)
|
|
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
|
|
if measuredSize := len(b) - before; size != measuredSize && err == nil {
|
|
return nil, errors.MismatchedSizeCalculation(size, measuredSize)
|
|
}
|
|
return b, err
|
|
} else {
|
|
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
|
|
val := pointerOfValue(valrv)
|
|
valSize := f.mi.sizePointer(val, opts)
|
|
size := 0
|
|
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
size += mapValTagSize + protowire.SizeBytes(valSize)
|
|
b = protowire.AppendVarint(b, uint64(size))
|
|
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b = protowire.AppendVarint(b, mapi.valWiretag)
|
|
b = protowire.AppendVarint(b, uint64(valSize))
|
|
before := len(b)
|
|
b, err = f.mi.marshalAppendPointer(b, val, opts)
|
|
if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
|
|
return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
|
|
}
|
|
return b, err
|
|
}
|
|
}
|
|
|
|
func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
|
|
if mapv.Len() == 0 {
|
|
return b, nil
|
|
}
|
|
if opts.Deterministic() {
|
|
return appendMapDeterministic(b, mapv, mapi, f, opts)
|
|
}
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
var err error
|
|
b = protowire.AppendVarint(b, f.wiretag)
|
|
b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
|
|
if err != nil {
|
|
return b, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
|
|
keys := mapv.MapKeys()
|
|
sort.Slice(keys, func(i, j int) bool {
|
|
switch keys[i].Kind() {
|
|
case reflect.Bool:
|
|
return !keys[i].Bool() && keys[j].Bool()
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return keys[i].Int() < keys[j].Int()
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
|
return keys[i].Uint() < keys[j].Uint()
|
|
case reflect.Float32, reflect.Float64:
|
|
return keys[i].Float() < keys[j].Float()
|
|
case reflect.String:
|
|
return keys[i].String() < keys[j].String()
|
|
default:
|
|
panic("invalid kind: " + keys[i].Kind().String())
|
|
}
|
|
})
|
|
for _, key := range keys {
|
|
var err error
|
|
b = protowire.AppendVarint(b, f.wiretag)
|
|
b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
|
|
if err != nil {
|
|
return b, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
|
|
if mi := f.mi; mi != nil {
|
|
mi.init()
|
|
if !mi.needsInitCheck {
|
|
return nil
|
|
}
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
val := pointerOfValue(iter.Value())
|
|
if err := mi.checkInitializedPointer(val); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
val := mapi.conv.valConv.PBValueOf(iter.Value())
|
|
if err := mapi.valFuncs.isInit(val); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
|
|
dstm := dst.AsValueOf(f.ft).Elem()
|
|
srcm := src.AsValueOf(f.ft).Elem()
|
|
if srcm.Len() == 0 {
|
|
return
|
|
}
|
|
if dstm.IsNil() {
|
|
dstm.Set(reflect.MakeMap(f.ft))
|
|
}
|
|
iter := mapRange(srcm)
|
|
for iter.Next() {
|
|
dstm.SetMapIndex(iter.Key(), iter.Value())
|
|
}
|
|
}
|
|
|
|
func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
|
|
dstm := dst.AsValueOf(f.ft).Elem()
|
|
srcm := src.AsValueOf(f.ft).Elem()
|
|
if srcm.Len() == 0 {
|
|
return
|
|
}
|
|
if dstm.IsNil() {
|
|
dstm.Set(reflect.MakeMap(f.ft))
|
|
}
|
|
iter := mapRange(srcm)
|
|
for iter.Next() {
|
|
dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
|
|
}
|
|
}
|
|
|
|
func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
|
|
dstm := dst.AsValueOf(f.ft).Elem()
|
|
srcm := src.AsValueOf(f.ft).Elem()
|
|
if srcm.Len() == 0 {
|
|
return
|
|
}
|
|
if dstm.IsNil() {
|
|
dstm.Set(reflect.MakeMap(f.ft))
|
|
}
|
|
iter := mapRange(srcm)
|
|
for iter.Next() {
|
|
val := reflect.New(f.ft.Elem().Elem())
|
|
if f.mi != nil {
|
|
f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
|
|
} else {
|
|
opts.Merge(asMessage(val), asMessage(iter.Value()))
|
|
}
|
|
dstm.SetMapIndex(iter.Key(), val)
|
|
}
|
|
}
|