mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-02-04 03:39:48 +00:00
b5523e32b3
When proto.Unmarshal fails, it is almost always due to being passed input which is not a wire-format message. (A text-format message, a message with framing left intact, and so forth.) An error like "variable length integer overflow" can be confusing in this case, since it implies the problem is the varint rather than the input being entirely wrong. Replace all Unmarshal parse errors with "cannot parse invalid wire-format data". Change-Id: Id97253bd39ac604e569df71778194f37b3c86c28 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/244297 Reviewed-by: Joe Tsai <joetsai@google.com>
389 lines
10 KiB
Go
389 lines
10 KiB
Go
// Copyright 2019 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package impl
|
|
|
|
import (
|
|
"reflect"
|
|
"sort"
|
|
|
|
"google.golang.org/protobuf/encoding/protowire"
|
|
"google.golang.org/protobuf/internal/genid"
|
|
pref "google.golang.org/protobuf/reflect/protoreflect"
|
|
)
|
|
|
|
type mapInfo struct {
|
|
goType reflect.Type
|
|
keyWiretag uint64
|
|
valWiretag uint64
|
|
keyFuncs valueCoderFuncs
|
|
valFuncs valueCoderFuncs
|
|
keyZero pref.Value
|
|
keyKind pref.Kind
|
|
conv *mapConverter
|
|
}
|
|
|
|
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
|
|
// TODO: Consider generating specialized map coders.
|
|
keyField := fd.MapKey()
|
|
valField := fd.MapValue()
|
|
keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
|
|
valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
|
|
keyFuncs := encoderFuncsForValue(keyField)
|
|
valFuncs := encoderFuncsForValue(valField)
|
|
conv := newMapConverter(ft, fd)
|
|
|
|
mapi := &mapInfo{
|
|
goType: ft,
|
|
keyWiretag: keyWiretag,
|
|
valWiretag: valWiretag,
|
|
keyFuncs: keyFuncs,
|
|
valFuncs: valFuncs,
|
|
keyZero: keyField.Default(),
|
|
keyKind: keyField.Kind(),
|
|
conv: conv,
|
|
}
|
|
if valField.Kind() == pref.MessageKind {
|
|
valueMessage = getMessageInfo(ft.Elem())
|
|
}
|
|
|
|
funcs = pointerCoderFuncs{
|
|
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
|
|
return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
|
|
},
|
|
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
|
|
return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
|
|
},
|
|
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
|
|
mp := p.AsValueOf(ft)
|
|
if mp.Elem().IsNil() {
|
|
mp.Elem().Set(reflect.MakeMap(mapi.goType))
|
|
}
|
|
if f.mi == nil {
|
|
return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
|
|
} else {
|
|
return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
|
|
}
|
|
},
|
|
}
|
|
switch valField.Kind() {
|
|
case pref.MessageKind:
|
|
funcs.merge = mergeMapOfMessage
|
|
case pref.BytesKind:
|
|
funcs.merge = mergeMapOfBytes
|
|
default:
|
|
funcs.merge = mergeMap
|
|
}
|
|
if valFuncs.isInit != nil {
|
|
funcs.isInit = func(p pointer, f *coderFieldInfo) error {
|
|
return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
|
|
}
|
|
}
|
|
return valueMessage, funcs
|
|
}
|
|
|
|
// Tag sizes passed to the key and value coders when sizing a map entry.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
|
|
|
|
func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
|
|
if mapv.Len() == 0 {
|
|
return 0
|
|
}
|
|
n := 0
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
|
|
keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
var valSize int
|
|
value := mapi.conv.valConv.PBValueOf(iter.Value())
|
|
if f.mi == nil {
|
|
valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
|
|
} else {
|
|
p := pointerOfValue(iter.Value())
|
|
valSize += mapValTagSize
|
|
valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
|
|
}
|
|
n += f.tagsize + protowire.SizeBytes(keySize+valSize)
|
|
}
|
|
return n
|
|
}
|
|
|
|
func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
|
|
if wtyp != protowire.BytesType {
|
|
return out, errUnknown
|
|
}
|
|
b, n := protowire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
var (
|
|
key = mapi.keyZero
|
|
val = mapi.conv.valConv.New()
|
|
)
|
|
for len(b) > 0 {
|
|
num, wtyp, n := protowire.ConsumeTag(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
if num > protowire.MaxValidNumber {
|
|
return out, errDecode
|
|
}
|
|
b = b[n:]
|
|
err := errUnknown
|
|
switch num {
|
|
case genid.MapEntry_Key_field_number:
|
|
var v pref.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
key = v
|
|
n = o.n
|
|
case genid.MapEntry_Value_field_number:
|
|
var v pref.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
val = v
|
|
n = o.n
|
|
}
|
|
if err == errUnknown {
|
|
n = protowire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
} else if err != nil {
|
|
return out, err
|
|
}
|
|
b = b[n:]
|
|
}
|
|
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
|
|
out.n = n
|
|
return out, nil
|
|
}
|
|
|
|
func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
|
|
if wtyp != protowire.BytesType {
|
|
return out, errUnknown
|
|
}
|
|
b, n := protowire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
var (
|
|
key = mapi.keyZero
|
|
val = reflect.New(f.mi.GoReflectType.Elem())
|
|
)
|
|
for len(b) > 0 {
|
|
num, wtyp, n := protowire.ConsumeTag(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
if num > protowire.MaxValidNumber {
|
|
return out, errDecode
|
|
}
|
|
b = b[n:]
|
|
err := errUnknown
|
|
switch num {
|
|
case 1:
|
|
var v pref.Value
|
|
var o unmarshalOutput
|
|
v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
|
|
if err != nil {
|
|
break
|
|
}
|
|
key = v
|
|
n = o.n
|
|
case 2:
|
|
if wtyp != protowire.BytesType {
|
|
break
|
|
}
|
|
var v []byte
|
|
v, n = protowire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
var o unmarshalOutput
|
|
o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
|
|
if o.initialized {
|
|
// Consider this map item initialized so long as we see
|
|
// an initialized value.
|
|
out.initialized = true
|
|
}
|
|
}
|
|
if err == errUnknown {
|
|
n = protowire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return out, errDecode
|
|
}
|
|
} else if err != nil {
|
|
return out, err
|
|
}
|
|
b = b[n:]
|
|
}
|
|
mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
|
|
out.n = n
|
|
return out, nil
|
|
}
|
|
|
|
func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
|
|
if f.mi == nil {
|
|
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
|
|
val := mapi.conv.valConv.PBValueOf(valrv)
|
|
size := 0
|
|
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
size += mapi.valFuncs.size(val, mapValTagSize, opts)
|
|
b = protowire.AppendVarint(b, uint64(size))
|
|
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
|
|
} else {
|
|
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
|
|
val := pointerOfValue(valrv)
|
|
valSize := f.mi.sizePointer(val, opts)
|
|
size := 0
|
|
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
size += mapValTagSize + protowire.SizeBytes(valSize)
|
|
b = protowire.AppendVarint(b, uint64(size))
|
|
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
b = protowire.AppendVarint(b, mapi.valWiretag)
|
|
b = protowire.AppendVarint(b, uint64(valSize))
|
|
return f.mi.marshalAppendPointer(b, val, opts)
|
|
}
|
|
}
|
|
|
|
func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
|
|
if mapv.Len() == 0 {
|
|
return b, nil
|
|
}
|
|
if opts.Deterministic() {
|
|
return appendMapDeterministic(b, mapv, mapi, f, opts)
|
|
}
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
var err error
|
|
b = protowire.AppendVarint(b, f.wiretag)
|
|
b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
|
|
if err != nil {
|
|
return b, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
|
|
keys := mapv.MapKeys()
|
|
sort.Slice(keys, func(i, j int) bool {
|
|
switch keys[i].Kind() {
|
|
case reflect.Bool:
|
|
return !keys[i].Bool() && keys[j].Bool()
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return keys[i].Int() < keys[j].Int()
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
|
return keys[i].Uint() < keys[j].Uint()
|
|
case reflect.Float32, reflect.Float64:
|
|
return keys[i].Float() < keys[j].Float()
|
|
case reflect.String:
|
|
return keys[i].String() < keys[j].String()
|
|
default:
|
|
panic("invalid kind: " + keys[i].Kind().String())
|
|
}
|
|
})
|
|
for _, key := range keys {
|
|
var err error
|
|
b = protowire.AppendVarint(b, f.wiretag)
|
|
b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
|
|
if err != nil {
|
|
return b, err
|
|
}
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
|
|
if mi := f.mi; mi != nil {
|
|
mi.init()
|
|
if !mi.needsInitCheck {
|
|
return nil
|
|
}
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
val := pointerOfValue(iter.Value())
|
|
if err := mi.checkInitializedPointer(val); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
iter := mapRange(mapv)
|
|
for iter.Next() {
|
|
val := mapi.conv.valConv.PBValueOf(iter.Value())
|
|
if err := mapi.valFuncs.isInit(val); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
|
|
dstm := dst.AsValueOf(f.ft).Elem()
|
|
srcm := src.AsValueOf(f.ft).Elem()
|
|
if srcm.Len() == 0 {
|
|
return
|
|
}
|
|
if dstm.IsNil() {
|
|
dstm.Set(reflect.MakeMap(f.ft))
|
|
}
|
|
iter := mapRange(srcm)
|
|
for iter.Next() {
|
|
dstm.SetMapIndex(iter.Key(), iter.Value())
|
|
}
|
|
}
|
|
|
|
func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
|
|
dstm := dst.AsValueOf(f.ft).Elem()
|
|
srcm := src.AsValueOf(f.ft).Elem()
|
|
if srcm.Len() == 0 {
|
|
return
|
|
}
|
|
if dstm.IsNil() {
|
|
dstm.Set(reflect.MakeMap(f.ft))
|
|
}
|
|
iter := mapRange(srcm)
|
|
for iter.Next() {
|
|
dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
|
|
}
|
|
}
|
|
|
|
func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
|
|
dstm := dst.AsValueOf(f.ft).Elem()
|
|
srcm := src.AsValueOf(f.ft).Elem()
|
|
if srcm.Len() == 0 {
|
|
return
|
|
}
|
|
if dstm.IsNil() {
|
|
dstm.Set(reflect.MakeMap(f.ft))
|
|
}
|
|
iter := mapRange(srcm)
|
|
for iter.Next() {
|
|
val := reflect.New(f.ft.Elem().Elem())
|
|
if f.mi != nil {
|
|
f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
|
|
} else {
|
|
opts.Merge(asMessage(val), asMessage(iter.Value()))
|
|
}
|
|
dstm.SetMapIndex(iter.Key(), val)
|
|
}
|
|
}
|