// Copyright 2019 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package impl import ( "reflect" "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/mapsort" pref "google.golang.org/protobuf/reflect/protoreflect" ) type mapInfo struct { goType reflect.Type keyWiretag uint64 valWiretag uint64 keyFuncs valueCoderFuncs valFuncs valueCoderFuncs keyZero pref.Value keyKind pref.Kind } func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) { // TODO: Consider generating specialized map coders. keyField := fd.MapKey() valField := fd.MapValue() keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()]) valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()]) keyFuncs := encoderFuncsForValue(keyField) valFuncs := encoderFuncsForValue(valField) conv := NewConverter(ft, fd) mapi := &mapInfo{ goType: ft, keyWiretag: keyWiretag, valWiretag: valWiretag, keyFuncs: keyFuncs, valFuncs: valFuncs, keyZero: keyField.Default(), keyKind: keyField.Kind(), } funcs = pointerCoderFuncs{ size: func(p pointer, tagsize int, opts marshalOptions) int { mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map() return sizeMap(mapv, tagsize, mapi, opts) }, marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) { mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map() return appendMap(b, mapv, wiretag, mapi, opts) }, unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) { mp := p.AsValueOf(ft) if mp.Elem().IsNil() { mp.Elem().Set(reflect.MakeMap(mapi.goType)) } mapv := conv.PBValueOf(mp.Elem()).Map() return consumeMap(b, mapv, wtyp, mapi, opts) }, } if valFuncs.isInit != nil { funcs.isInit = func(p pointer) error { mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map() return isInitMap(mapv, mapi) } } return funcs } const ( mapKeyTagSize = 1 // field 1, tag size 1. mapValTagSize = 1 // field 2, tag size 2. ) func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int { if mapv.Len() == 0 { return 0 } n := 0 mapv.Range(func(key pref.MapKey, value pref.Value) bool { n += tagsize + wire.SizeBytes( mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+ mapi.valFuncs.size(value, mapValTagSize, opts)) return true }) return n } func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) { if wtyp != wire.BytesType { return 0, errUnknown } b, n := wire.ConsumeBytes(b) if n < 0 { return 0, wire.ParseError(n) } var ( key = mapi.keyZero val = mapv.NewValue() ) for len(b) > 0 { num, wtyp, n := wire.ConsumeTag(b) if n < 0 { return 0, wire.ParseError(n) } b = b[n:] err := errUnknown switch num { case 1: var v pref.Value v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) if err != nil { break } key = v case 2: var v pref.Value v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts) if err != nil { break } val = v } if err == errUnknown { n = wire.ConsumeFieldValue(num, wtyp, b) if n < 0 { return 0, wire.ParseError(n) } } else if err != nil { return 0, err } b = b[n:] } mapv.Set(key.MapKey(), val) return n, nil } func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) { if mapv.Len() == 0 { return b, nil } var err error fn := func(key pref.MapKey, value pref.Value) bool { b = wire.AppendVarint(b, wiretag) size := 0 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) size += mapi.valFuncs.size(value, mapValTagSize, opts) b = wire.AppendVarint(b, uint64(size)) b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) if err != nil { return false } b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts) if err != nil { return false } return true } if opts.Deterministic() { mapsort.Range(mapv, mapi.keyKind, fn) } else { mapv.Range(fn) } return b, err } func isInitMap(mapv pref.Map, mapi *mapInfo) error { var err error mapv.Range(func(_ pref.MapKey, value pref.Value) bool { err = mapi.valFuncs.isInit(value) return err == nil }) return err }