protobuf-go/internal/impl/codec_map.go
Damien Neil 4b3a82f6b1 internal/impl: clean up Value codecs
Remove the Go type from consideration when creating Value codecs, as it
is unnecessary. Value codecs convert between wire form and Values,
while Converters convert between Values and the Go type.

Change-Id: Iaa4bc7db81ad0a29dabd42c2229e6f33a0c91c67
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/193457
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
2019-09-05 16:31:05 +00:00

177 lines
4.4 KiB
Go

// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"reflect"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/mapsort"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
// mapInfo bundles the per-field state shared by the map coder functions.
type mapInfo struct {
	// goType is the Go map type; it is passed to reflect.MakeMap when
	// unmarshaling into a nil map.
	goType reflect.Type

	// keyWiretag and valWiretag are the encoded tags of the key (field 1)
	// and value (field 2) inside each map entry.
	keyWiretag, valWiretag uint64

	// keyFuncs and valFuncs are the value coders for the key and value fields.
	keyFuncs, valFuncs valueCoderFuncs

	// keyZero is the default value of the key field; consumeMap uses it as
	// the starting key for each entry it parses.
	keyZero pref.Value

	// keyKind is the kind of the key field, passed to mapsort.Range for
	// deterministic marshaling.
	keyKind pref.Kind
}
// encoderFuncsForMap builds the pointer coder functions for the map field
// described by fd, whose Go representation has type ft.
//
// The returned funcs always carry size, marshal and unmarshal. isInit is set
// only when the value coder itself has an isInit function.
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
	// TODO: Consider generating specialized map coders.
	keyDesc, valDesc := fd.MapKey(), fd.MapValue()
	info := &mapInfo{
		goType:     ft,
		keyWiretag: wire.EncodeTag(1, wireTypes[keyDesc.Kind()]),
		valWiretag: wire.EncodeTag(2, wireTypes[valDesc.Kind()]),
		keyFuncs:   encoderFuncsForValue(keyDesc),
		valFuncs:   encoderFuncsForValue(valDesc),
	}
	conv := NewConverter(ft, fd)
	info.keyZero = keyDesc.Default()
	info.keyKind = keyDesc.Kind()

	// asMap views the Go map that p points at as a protoreflect map.
	asMap := func(p pointer) pref.Map {
		return conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
	}

	funcs.size = func(p pointer, tagsize int, opts marshalOptions) int {
		return sizeMap(asMap(p), tagsize, info, opts)
	}
	funcs.marshal = func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
		return appendMap(b, asMap(p), wiretag, info, opts)
	}
	funcs.unmarshal = func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
		// Allocate the map on first use so that entries can be stored in it.
		slot := p.AsValueOf(ft).Elem()
		if slot.IsNil() {
			slot.Set(reflect.MakeMap(info.goType))
		}
		return consumeMap(b, conv.PBValueOf(slot).Map(), wtyp, info, opts)
	}
	if info.valFuncs.isInit != nil {
		funcs.isInit = func(p pointer) error {
			return isInitMap(asMap(p), info)
		}
	}
	return funcs
}
// Sizes in bytes of the key and value tags inside a map entry; these are
// passed as the tag size to the key and value size functions.
// The key is field 1 and the value is field 2; field numbers this small
// encode as a one-byte tag for any wire type.
const (
mapKeyTagSize = 1 // field 1, tag size 1.
mapValTagSize = 1 // field 2, tag size 1.
)
// sizeMap returns the encoded size of every entry in mapv. Each entry is
// counted as the field tag (tagsize bytes) plus a length-prefixed payload
// holding the key and value. An empty map has size 0.
func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
	if mapv.Len() == 0 {
		return 0
	}
	var total int
	mapv.Range(func(k pref.MapKey, v pref.Value) bool {
		keySize := mapi.keyFuncs.size(k.Value(), mapKeyTagSize, opts)
		valSize := mapi.valFuncs.size(v, mapValTagSize, opts)
		total += tagsize + wire.SizeBytes(keySize+valSize)
		return true
	})
	return total
}
// consumeMap parses a single length-prefixed map entry from b and stores it
// in mapv.
//
// It returns errUnknown when wtyp is not the bytes wire type. Inside the
// entry, field 1 is decoded as the key and field 2 as the value; any field
// that is not recognized, or that a coder reports as errUnknown, is skipped.
// A key or value absent from the entry keeps its starting value
// (mapi.keyZero and mapv.NewValue() respectively). On success the result is
// the number of bytes consumed from b, as reported by wire.ConsumeBytes.
func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
	if wtyp != wire.BytesType {
		return 0, errUnknown
	}
	entry, consumed := wire.ConsumeBytes(b)
	if consumed < 0 {
		return 0, wire.ParseError(consumed)
	}

	key := mapi.keyZero
	val := mapv.NewValue()
	for len(entry) > 0 {
		num, fieldType, tagLen := wire.ConsumeTag(entry)
		if tagLen < 0 {
			return 0, wire.ParseError(tagLen)
		}
		entry = entry[tagLen:]

		// Fields other than 1 and 2 leave err as errUnknown and are skipped.
		var fieldLen int
		err := errUnknown
		switch num {
		case 1:
			var k pref.Value
			if k, fieldLen, err = mapi.keyFuncs.unmarshal(entry, key, num, fieldType, opts); err == nil {
				key = k
			}
		case 2:
			var v pref.Value
			if v, fieldLen, err = mapi.valFuncs.unmarshal(entry, val, num, fieldType, opts); err == nil {
				val = v
			}
		}

		switch {
		case err == errUnknown:
			fieldLen = wire.ConsumeFieldValue(num, fieldType, entry)
			if fieldLen < 0 {
				return 0, wire.ParseError(fieldLen)
			}
		case err != nil:
			return 0, err
		}
		entry = entry[fieldLen:]
	}
	mapv.Set(key.MapKey(), val)
	return consumed, nil
}
// appendMap appends the wire encoding of every entry in mapv to b and
// returns the extended slice. Each entry is written as wiretag, a varint
// length, then the key followed by the value. An empty map appends nothing.
//
// When opts.Deterministic() is set, entries are visited through
// mapsort.Range; otherwise the map's own Range order is used. Iteration
// stops at the first marshal error, which is returned together with the
// bytes written so far.
func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
	if mapv.Len() == 0 {
		return b, nil
	}

	var marshalErr error
	emit := func(k pref.MapKey, v pref.Value) bool {
		b = wire.AppendVarint(b, wiretag)
		entrySize := mapi.keyFuncs.size(k.Value(), mapKeyTagSize, opts) +
			mapi.valFuncs.size(v, mapValTagSize, opts)
		b = wire.AppendVarint(b, uint64(entrySize))

		if b, marshalErr = mapi.keyFuncs.marshal(b, k.Value(), mapi.keyWiretag, opts); marshalErr != nil {
			return false
		}
		b, marshalErr = mapi.valFuncs.marshal(b, v, mapi.valWiretag, opts)
		return marshalErr == nil
	}

	if opts.Deterministic() {
		mapsort.Range(mapv, mapi.keyKind, emit)
	} else {
		mapv.Range(emit)
	}
	return b, marshalErr
}
// isInitMap runs the value coder's isInit check on each value in mapv and
// returns the first error encountered, or nil if every value passes.
func isInitMap(mapv pref.Map, mapi *mapInfo) error {
	var firstErr error
	mapv.Range(func(_ pref.MapKey, v pref.Value) bool {
		if e := mapi.valFuncs.isInit(v); e != nil {
			firstErr = e
			return false // stop iterating at the first failure
		}
		return true
	})
	return firstErr
}