2019-04-01 13:49:56 -07:00
|
|
|
// Copyright 2019 The Go Authors. All rights reserved.
|
|
|
|
// Use of this source code is governed by a BSD-style
|
|
|
|
// license that can be found in the LICENSE file.
|
|
|
|
|
|
|
|
package impl
|
|
|
|
|
|
|
|
import (
|
|
|
|
"reflect"
|
|
|
|
|
|
|
|
"google.golang.org/protobuf/internal/encoding/wire"
|
2019-08-22 11:41:32 -07:00
|
|
|
"google.golang.org/protobuf/internal/mapsort"
|
2019-04-01 13:49:56 -07:00
|
|
|
pref "google.golang.org/protobuf/reflect/protoreflect"
|
|
|
|
)
|
|
|
|
|
2019-06-27 10:54:42 -07:00
|
|
|
type mapInfo struct {
|
|
|
|
goType reflect.Type
|
|
|
|
keyWiretag uint64
|
|
|
|
valWiretag uint64
|
2019-08-22 11:41:32 -07:00
|
|
|
keyFuncs valueCoderFuncs
|
|
|
|
valFuncs valueCoderFuncs
|
|
|
|
keyZero pref.Value
|
|
|
|
keyKind pref.Kind
|
2019-06-27 10:54:42 -07:00
|
|
|
}
|
|
|
|
|
2019-04-01 13:49:56 -07:00
|
|
|
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
|
|
|
|
// TODO: Consider generating specialized map coders.
|
|
|
|
keyField := fd.MapKey()
|
|
|
|
valField := fd.MapValue()
|
|
|
|
keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
|
|
|
|
valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
|
|
|
|
keyFuncs := encoderFuncsForValue(keyField, ft.Key())
|
|
|
|
valFuncs := encoderFuncsForValue(valField, ft.Elem())
|
2019-08-22 11:41:32 -07:00
|
|
|
conv := NewConverter(ft, fd)
|
2019-04-01 13:49:56 -07:00
|
|
|
|
2019-06-27 10:54:42 -07:00
|
|
|
mapi := &mapInfo{
|
|
|
|
goType: ft,
|
|
|
|
keyWiretag: keyWiretag,
|
|
|
|
valWiretag: valWiretag,
|
|
|
|
keyFuncs: keyFuncs,
|
|
|
|
valFuncs: valFuncs,
|
2019-08-22 11:41:32 -07:00
|
|
|
keyZero: keyField.Default(),
|
|
|
|
keyKind: keyField.Kind(),
|
2019-06-27 10:54:42 -07:00
|
|
|
}
|
|
|
|
|
2019-04-09 15:57:05 -07:00
|
|
|
funcs = pointerCoderFuncs{
|
2019-04-01 13:49:56 -07:00
|
|
|
size: func(p pointer, tagsize int, opts marshalOptions) int {
|
2019-08-22 11:41:32 -07:00
|
|
|
mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
|
|
|
|
return sizeMap(mapv, tagsize, mapi, opts)
|
2019-04-01 13:49:56 -07:00
|
|
|
},
|
|
|
|
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
|
2019-08-22 11:41:32 -07:00
|
|
|
mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
|
|
|
|
return appendMap(b, mapv, wiretag, mapi, opts)
|
2019-04-01 13:49:56 -07:00
|
|
|
},
|
2019-06-27 10:54:42 -07:00
|
|
|
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
|
2019-08-22 11:41:32 -07:00
|
|
|
mp := p.AsValueOf(ft)
|
|
|
|
if mp.Elem().IsNil() {
|
|
|
|
mp.Elem().Set(reflect.MakeMap(mapi.goType))
|
|
|
|
}
|
|
|
|
mapv := conv.PBValueOf(mp.Elem()).Map()
|
|
|
|
return consumeMap(b, mapv, wtyp, mapi, opts)
|
2019-06-27 10:54:42 -07:00
|
|
|
},
|
2019-04-01 13:49:56 -07:00
|
|
|
}
|
2019-04-09 15:57:05 -07:00
|
|
|
if valFuncs.isInit != nil {
|
|
|
|
funcs.isInit = func(p pointer) error {
|
2019-08-22 11:41:32 -07:00
|
|
|
mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
|
|
|
|
return isInitMap(mapv, mapi)
|
2019-04-09 15:57:05 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return funcs
|
2019-04-01 13:49:56 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
// Sizes in bytes of the tags that precede the key (field 1) and the value
// (field 2) inside a map-entry record. Field numbers 1 and 2 combined with any
// wire type produce a tag value below 128, which encodes as a one-byte varint.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
|
|
|
|
|
2019-08-22 11:41:32 -07:00
|
|
|
func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
|
|
|
|
if mapv.Len() == 0 {
|
|
|
|
return 0
|
2019-06-27 10:54:42 -07:00
|
|
|
}
|
2019-08-22 11:41:32 -07:00
|
|
|
n := 0
|
|
|
|
mapv.Range(func(key pref.MapKey, value pref.Value) bool {
|
|
|
|
n += tagsize + wire.SizeBytes(
|
|
|
|
mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+
|
|
|
|
mapi.valFuncs.size(value, mapValTagSize, opts))
|
|
|
|
return true
|
|
|
|
})
|
|
|
|
return n
|
|
|
|
}
|
2019-06-27 10:54:42 -07:00
|
|
|
|
2019-08-22 11:41:32 -07:00
|
|
|
func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
|
2019-06-27 10:54:42 -07:00
|
|
|
if wtyp != wire.BytesType {
|
|
|
|
return 0, errUnknown
|
|
|
|
}
|
|
|
|
b, n := wire.ConsumeBytes(b)
|
|
|
|
if n < 0 {
|
|
|
|
return 0, wire.ParseError(n)
|
|
|
|
}
|
|
|
|
var (
|
|
|
|
key = mapi.keyZero
|
2019-08-22 11:41:32 -07:00
|
|
|
val = mapv.NewValue()
|
2019-06-27 10:54:42 -07:00
|
|
|
)
|
|
|
|
for len(b) > 0 {
|
|
|
|
num, wtyp, n := wire.ConsumeTag(b)
|
|
|
|
if n < 0 {
|
|
|
|
return 0, wire.ParseError(n)
|
|
|
|
}
|
|
|
|
b = b[n:]
|
|
|
|
err := errUnknown
|
|
|
|
switch num {
|
|
|
|
case 1:
|
2019-08-22 11:41:32 -07:00
|
|
|
var v pref.Value
|
2019-06-27 10:54:42 -07:00
|
|
|
v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
|
|
|
|
if err != nil {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
key = v
|
|
|
|
case 2:
|
2019-08-22 11:41:32 -07:00
|
|
|
var v pref.Value
|
2019-06-27 10:54:42 -07:00
|
|
|
v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
|
|
|
|
if err != nil {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
val = v
|
|
|
|
}
|
|
|
|
if err == errUnknown {
|
|
|
|
n = wire.ConsumeFieldValue(num, wtyp, b)
|
|
|
|
if n < 0 {
|
|
|
|
return 0, wire.ParseError(n)
|
|
|
|
}
|
|
|
|
} else if err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
b = b[n:]
|
|
|
|
}
|
2019-08-22 11:41:32 -07:00
|
|
|
mapv.Set(key.MapKey(), val)
|
2019-06-27 10:54:42 -07:00
|
|
|
return n, nil
|
|
|
|
}
|
|
|
|
|
2019-08-22 11:41:32 -07:00
|
|
|
func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
|
|
|
|
if mapv.Len() == 0 {
|
2019-06-19 09:28:29 -07:00
|
|
|
return b, nil
|
2019-04-01 13:49:56 -07:00
|
|
|
}
|
2019-08-22 11:41:32 -07:00
|
|
|
var err error
|
|
|
|
fn := func(key pref.MapKey, value pref.Value) bool {
|
|
|
|
b = wire.AppendVarint(b, wiretag)
|
|
|
|
size := 0
|
|
|
|
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
|
|
|
|
size += mapi.valFuncs.size(value, mapValTagSize, opts)
|
|
|
|
b = wire.AppendVarint(b, uint64(size))
|
|
|
|
b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
|
2019-06-19 09:28:29 -07:00
|
|
|
if err != nil {
|
2019-08-22 11:41:32 -07:00
|
|
|
return false
|
2019-04-01 13:49:56 -07:00
|
|
|
}
|
2019-08-22 11:41:32 -07:00
|
|
|
b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts)
|
|
|
|
if err != nil {
|
|
|
|
return false
|
2019-04-09 15:57:05 -07:00
|
|
|
}
|
2019-08-22 11:41:32 -07:00
|
|
|
return true
|
2019-04-09 15:57:05 -07:00
|
|
|
}
|
2019-08-22 11:41:32 -07:00
|
|
|
if opts.Deterministic() {
|
|
|
|
mapsort.Range(mapv, mapi.keyKind, fn)
|
|
|
|
} else {
|
|
|
|
mapv.Range(fn)
|
2019-04-01 13:49:56 -07:00
|
|
|
}
|
2019-08-22 11:41:32 -07:00
|
|
|
return b, err
|
2019-04-01 13:49:56 -07:00
|
|
|
}
|
|
|
|
|
2019-08-22 11:41:32 -07:00
|
|
|
func isInitMap(mapv pref.Map, mapi *mapInfo) error {
|
|
|
|
var err error
|
|
|
|
mapv.Range(func(_ pref.MapKey, value pref.Value) bool {
|
|
|
|
err = mapi.valFuncs.isInit(value)
|
|
|
|
return err == nil
|
|
|
|
})
|
|
|
|
return err
|
2019-04-01 13:49:56 -07:00
|
|
|
}
|