internal/order: add a package for ordered iteration over messages and maps

The order package replaces the mapsort and fieldsort packages.
It presents a common API for ordered iteration over message fields
and map fields.

It has a number of pre-defined orderings.

Change-Id: Ie6cd423da30b4757864c352cb04454f21fe07ee2
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/239837
Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
Joe Tsai 2020-06-24 14:28:07 -07:00
parent b78321453d
commit 92679665d7
14 changed files with 532 additions and 434 deletions

View File

@ -7,13 +7,13 @@ package protojson
import (
"encoding/base64"
"fmt"
"sort"
"google.golang.org/protobuf/internal/encoding/json"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
@ -160,61 +160,71 @@ func (e encoder) marshalMessage(m pref.Message) error {
return nil
}
// unpopulatedFieldRanger wraps a protoreflect.Message and modifies its Range
// method to additionally iterate over unpopulated fields.
type unpopulatedFieldRanger struct{ pref.Message }
// Range first calls f for every unpopulated field that is not part of a oneof,
// in descriptor declaration order, and then delegates to the wrapped message's
// Range for the populated fields. If f returns false while the unpopulated
// fields are being visited, Range returns immediately without visiting the
// populated ones.
func (m unpopulatedFieldRanger) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
fds := m.Descriptor().Fields()
for i := 0; i < fds.Len(); i++ {
fd := fds.Get(i)
if m.Has(fd) || fd.ContainingOneof() != nil {
continue // ignore populated fields and fields within a oneof
}
// Start from the value reported by Get for the unset field; it is replaced
// below for field kinds that should be emitted as null instead.
v := m.Get(fd)
isProto2Scalar := fd.Syntax() == pref.Proto2 && fd.Default().IsValid()
isSingularMessage := fd.Cardinality() != pref.Repeated && fd.Message() != nil
if isProto2Scalar || isSingularMessage {
v = pref.Value{} // use invalid value to emit null
}
if !f(fd, v) {
return
}
}
m.Message.Range(f)
}
// marshalFields marshals the fields in the given protoreflect.Message.
func (e encoder) marshalFields(m pref.Message) error {
messageDesc := m.Descriptor()
if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) {
if !flags.ProtoLegacy && messageset.IsMessageSet(m.Descriptor()) {
return errors.New("no support for proto1 MessageSets")
}
// Marshal out known fields.
fieldDescs := messageDesc.Fields()
for i := 0; i < fieldDescs.Len(); {
fd := fieldDescs.Get(i)
if od := fd.ContainingOneof(); od != nil {
fd = m.WhichOneof(od)
i += od.Fields().Len()
if fd == nil {
continue // unpopulated oneofs are not affected by EmitUnpopulated
}
} else {
i++
}
var fields order.FieldRanger = m
if e.opts.EmitUnpopulated {
fields = unpopulatedFieldRanger{m}
}
val := m.Get(fd)
if !m.Has(fd) {
if !e.opts.EmitUnpopulated {
continue
var err error
order.RangeFields(fields, order.IndexNameFieldOrder, func(fd pref.FieldDescriptor, v pref.Value) bool {
var name string
switch {
case fd.IsExtension():
if messageset.IsMessageSetExtension(fd) {
name = "[" + string(fd.FullName().Parent()) + "]"
} else {
name = "[" + string(fd.FullName()) + "]"
}
isProto2Scalar := fd.Syntax() == pref.Proto2 && fd.Default().IsValid()
isSingularMessage := fd.Cardinality() != pref.Repeated && fd.Message() != nil
if isProto2Scalar || isSingularMessage {
// Use invalid value to emit null.
val = pref.Value{}
}
}
name := fd.JSONName()
if e.opts.UseProtoNames {
name = string(fd.Name())
// Use type name for group field name.
case e.opts.UseProtoNames:
if fd.Kind() == pref.GroupKind {
name = string(fd.Message().Name())
} else {
name = string(fd.Name())
}
default:
name = fd.JSONName()
}
if err := e.WriteName(name); err != nil {
return err
}
if err := e.marshalValue(val, fd); err != nil {
return err
}
}
// Marshal out extensions.
if err := e.marshalExtensions(m); err != nil {
return err
}
return nil
if err = e.WriteName(name); err != nil {
return false
}
if err = e.marshalValue(v, fd); err != nil {
return false
}
return true
})
return err
}
// marshalValue marshals the given protoreflect.Value.
@ -305,98 +315,20 @@ func (e encoder) marshalList(list pref.List, fd pref.FieldDescriptor) error {
return nil
}
type mapEntry struct {
key pref.MapKey
value pref.Value
}
// marshalMap marshals given protoreflect.Map.
func (e encoder) marshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
e.StartObject()
defer e.EndObject()
// Get a sorted list based on keyType first.
entries := make([]mapEntry, 0, mmap.Len())
mmap.Range(func(key pref.MapKey, val pref.Value) bool {
entries = append(entries, mapEntry{key: key, value: val})
var err error
order.RangeEntries(mmap, order.GenericKeyOrder, func(k pref.MapKey, v pref.Value) bool {
if err = e.WriteName(k.String()); err != nil {
return false
}
if err = e.marshalSingular(v, fd.MapValue()); err != nil {
return false
}
return true
})
sortMap(fd.MapKey().Kind(), entries)
// Write out sorted list.
for _, entry := range entries {
if err := e.WriteName(entry.key.String()); err != nil {
return err
}
if err := e.marshalSingular(entry.value, fd.MapValue()); err != nil {
return err
}
}
return nil
}
// sortMap orders list based on value of key field for deterministic ordering.
func sortMap(keyKind pref.Kind, values []mapEntry) {
sort.Slice(values, func(i, j int) bool {
switch keyKind {
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
return values[i].key.Int() < values[j].key.Int()
case pref.Uint32Kind, pref.Fixed32Kind,
pref.Uint64Kind, pref.Fixed64Kind:
return values[i].key.Uint() < values[j].key.Uint()
}
return values[i].key.String() < values[j].key.String()
})
}
// marshalExtensions marshals extension fields.
func (e encoder) marshalExtensions(m pref.Message) error {
type entry struct {
key string
value pref.Value
desc pref.FieldDescriptor
}
// Get a sorted list based on field key first.
var entries []entry
m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
if !fd.IsExtension() {
return true
}
// For MessageSet extensions, the name used is the parent message.
name := fd.FullName()
if messageset.IsMessageSetExtension(fd) {
name = name.Parent()
}
// Use [name] format for JSON field name.
entries = append(entries, entry{
key: string(name),
value: v,
desc: fd,
})
return true
})
// Sort extensions lexicographically.
sort.Slice(entries, func(i, j int) bool {
return entries[i].key < entries[j].key
})
// Write out sorted list.
for _, entry := range entries {
// JSON field name is the proto field name enclosed in [], similar to
// textproto. This is consistent with Go v1 lib. C++ lib v3.7.0 does not
// marshal out extension fields.
if err := e.WriteName("[" + entry.key + "]"); err != nil {
return err
}
if err := e.marshalValue(entry.value, entry.desc); err != nil {
return err
}
}
return nil
return err
}

View File

@ -1060,12 +1060,12 @@ func TestMarshal(t *testing.T) {
return m
}(),
want: `{
"[pb2.MessageSetExtension]": {
"optString": "a messageset extension"
},
"[pb2.MessageSetExtension.ext_nested]": {
"optString": "just a regular extension"
},
"[pb2.MessageSetExtension]": {
"optString": "a messageset extension"
},
"[pb2.MessageSetExtension.not_message_set_extension]": {
"optString": "not a messageset extension"
}
@ -2123,6 +2123,35 @@ func TestMarshal(t *testing.T) {
"optNested": null
}
]
}`,
}, {
desc: "EmitUnpopulated: with populated fields",
mo: protojson.MarshalOptions{EmitUnpopulated: true},
input: &pb2.Scalars{
OptInt32: proto.Int32(0xff),
OptUint32: proto.Uint32(47),
OptSint32: proto.Int32(-1001),
OptFixed32: proto.Uint32(32),
OptSfixed32: proto.Int32(-32),
OptFloat: proto.Float32(1.02),
OptBytes: []byte("谷歌"),
},
want: `{
"optBool": null,
"optInt32": 255,
"optInt64": null,
"optUint32": 47,
"optUint64": null,
"optSint32": -1001,
"optSint64": null,
"optFixed32": 32,
"optFixed64": null,
"optSfixed32": -32,
"optSfixed64": null,
"optFloat": 1.02,
"optDouble": null,
"optBytes": "6LC35q2M",
"optString": null
}`,
}, {
desc: "UseEnumNumbers in singular field",

View File

@ -6,7 +6,6 @@ package prototext
import (
"fmt"
"sort"
"strconv"
"unicode/utf8"
@ -16,10 +15,11 @@ import (
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/strs"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
@ -169,35 +169,30 @@ func (e encoder) marshalMessage(m pref.Message, inclDelims bool) error {
// If unable to expand, continue on to marshal Any as a regular message.
}
// Marshal known fields.
fieldDescs := messageDesc.Fields()
size := fieldDescs.Len()
for i := 0; i < size; {
fd := fieldDescs.Get(i)
if od := fd.ContainingOneof(); od != nil {
fd = m.WhichOneof(od)
i += od.Fields().Len()
// Marshal fields.
var err error
order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
var name string
if fd.IsExtension() {
if messageset.IsMessageSetExtension(fd) {
name = "[" + string(fd.FullName().Parent()) + "]"
} else {
name = "[" + string(fd.FullName()) + "]"
}
} else {
i++
if fd.Kind() == pref.GroupKind {
name = string(fd.Message().Name())
} else {
name = string(fd.Name())
}
}
if fd == nil || !m.Has(fd) {
continue
if err = e.marshalField(string(name), v, fd); err != nil {
return false
}
name := fd.Name()
// Use type name for group field name.
if fd.Kind() == pref.GroupKind {
name = fd.Message().Name()
}
val := m.Get(fd)
if err := e.marshalField(string(name), val, fd); err != nil {
return err
}
}
// Marshal extensions.
if err := e.marshalExtensions(m); err != nil {
return true
})
if err != nil {
return err
}
@ -290,7 +285,7 @@ func (e encoder) marshalList(name string, list pref.List, fd pref.FieldDescripto
// marshalMap marshals the given protoreflect.Map as multiple name-value fields.
func (e encoder) marshalMap(name string, mmap pref.Map, fd pref.FieldDescriptor) error {
var err error
mapsort.Range(mmap, fd.MapKey().Kind(), func(key pref.MapKey, val pref.Value) bool {
order.RangeEntries(mmap, order.GenericKeyOrder, func(key pref.MapKey, val pref.Value) bool {
e.WriteName(name)
e.StartMessage()
defer e.EndMessage()
@ -311,48 +306,6 @@ func (e encoder) marshalMap(name string, mmap pref.Map, fd pref.FieldDescriptor)
return err
}
// marshalExtensions marshals extension fields.
func (e encoder) marshalExtensions(m pref.Message) error {
type entry struct {
key string
value pref.Value
desc pref.FieldDescriptor
}
// Get a sorted list based on field key first.
var entries []entry
m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
if !fd.IsExtension() {
return true
}
// For MessageSet extensions, the name used is the parent message.
name := fd.FullName()
if messageset.IsMessageSetExtension(fd) {
name = name.Parent()
}
entries = append(entries, entry{
key: string(name),
value: v,
desc: fd,
})
return true
})
// Sort extensions lexicographically.
sort.Slice(entries, func(i, j int) bool {
return entries[i].key < entries[j].key
})
// Write out sorted list.
for _, entry := range entries {
// Extension field name is the proto field name enclosed in [].
name := "[" + entry.key + "]"
if err := e.marshalField(name, entry.value, entry.desc); err != nil {
return err
}
}
return nil
}
// marshalUnknown parses the given []byte and marshals fields out.
// This function assumes proper encoding in the given []byte.
func (e encoder) marshalUnknown(b []byte) {

View File

@ -1158,12 +1158,12 @@ opt_int32: 42
})
return m
}(),
want: `[pb2.MessageSetExtension]: {
opt_string: "a messageset extension"
}
[pb2.MessageSetExtension.ext_nested]: {
want: `[pb2.MessageSetExtension.ext_nested]: {
opt_string: "just a regular extension"
}
[pb2.MessageSetExtension]: {
opt_string: "a messageset extension"
}
[pb2.MessageSetExtension.not_message_set_extension]: {
opt_string: "not a messageset extension"
}

View File

@ -1,40 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fieldsort defines an ordering of fields.
//
// The ordering defined by this package matches the historic behavior of the proto
// package, placing extensions first and oneofs last.
//
// There is no guarantee about stability of the wire encoding, and users should not
// depend on the order defined in this package as it is subject to change without
// notice.
package fieldsort
import (
"google.golang.org/protobuf/reflect/protoreflect"
)
// Less returns true if field a comes before field b in ordered wire marshal output.
//
// The cases below are evaluated in order, so each one only applies when all
// earlier cases did not match.
func Less(a, b protoreflect.FieldDescriptor) bool {
ea := a.IsExtension()
eb := b.IsExtension()
oa := a.ContainingOneof()
ob := b.ContainingOneof()
switch {
// Exactly one of the two is an extension: the extension sorts first.
case ea != eb:
return ea
// Both fields belong to some oneof: order by field number within the same
// oneof, otherwise by the index of the containing oneof.
case oa != nil && ob != nil:
if oa == ob {
return a.Number() < b.Number()
}
return oa.Index() < ob.Index()
// Only a is in a non-synthetic oneof: a sorts after b.
case oa != nil && !oa.IsSynthetic():
return false
// Only b is in a non-synthetic oneof: a sorts before b.
case ob != nil && !ob.IsSynthetic():
return true
// Everything else is ordered by field number.
default:
return a.Number() < b.Number()
}
}

View File

@ -11,7 +11,7 @@ import (
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/fieldsort"
"google.golang.org/protobuf/internal/order"
pref "google.golang.org/protobuf/reflect/protoreflect"
piface "google.golang.org/protobuf/runtime/protoiface"
)
@ -136,7 +136,7 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
fi := fields.ByNumber(mi.orderedCoderFields[i].num)
fj := fields.ByNumber(mi.orderedCoderFields[j].num)
return fieldsort.Less(fi, fj)
return order.LegacyFieldOrder(fi, fj)
})
}

View File

@ -1,43 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package mapsort provides sorted access to maps.
package mapsort
import (
"sort"
"google.golang.org/protobuf/reflect/protoreflect"
)
// Range iterates over every map entry in sorted key order,
// calling f for each key and value encountered.
// Range visits the entries of mapv in ascending key order, invoking f with
// each key and its current value. Iteration stops as soon as f returns false.
//
// keyKind selects how two keys are compared: false sorts before true,
// signed and unsigned integers sort numerically, and strings sort with the
// built-in string comparison. Comparing keys of any other kind panics.
func Range(mapv protoreflect.Map, keyKind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) {
	// Snapshot the keys so they can be sorted independently of the map.
	var sorted []protoreflect.MapKey
	mapv.Range(func(k protoreflect.MapKey, _ protoreflect.Value) bool {
		sorted = append(sorted, k)
		return true
	})

	// The kind is inspected on every comparison (not up front) so that an
	// unsupported kind only panics when two keys actually need comparing.
	less := func(i, j int) bool {
		x, y := sorted[i], sorted[j]
		switch keyKind {
		case protoreflect.BoolKind:
			return !x.Bool() && y.Bool()
		case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind,
			protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
			return x.Int() < y.Int()
		case protoreflect.Uint32Kind, protoreflect.Fixed32Kind,
			protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
			return x.Uint() < y.Uint()
		case protoreflect.StringKind:
			return x.String() < y.String()
		default:
			panic("invalid kind: " + keyKind.String())
		}
	}
	sort.Slice(sorted, less)

	// Look each value up at visit time, in sorted key order.
	for i := 0; i < len(sorted); i++ {
		if !f(sorted[i], mapv.Get(sorted[i])) {
			return
		}
	}
}

View File

@ -1,69 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mapsort_test
import (
"strconv"
"testing"
"google.golang.org/protobuf/internal/mapsort"
pref "google.golang.org/protobuf/reflect/protoreflect"
testpb "google.golang.org/protobuf/internal/testprotos/test"
)
// TestRange checks that mapsort.Range visits map entries in sorted key order
// for bool, int32, uint64, and string keyed maps.
//
// Every fixture map stores each key under a value that encodes the key's
// expected position (false/true as 0/1, numbers as themselves, strings as
// their decimal form), so the position at which a key is visited can be
// compared against its value.
func TestRange(t *testing.T) {
m := (&testpb.TestAllTypes{
MapBoolBool: map[bool]bool{
false: false,
true: true,
},
MapInt32Int32: map[int32]int32{
0: 0,
1: 1,
2: 2,
},
MapUint64Uint64: map[uint64]uint64{
0: 0,
1: 1,
2: 2,
},
MapStringString: map[string]string{
"0": "0",
"1": "1",
"2": "2",
},
}).ProtoReflect()
// Only the four map fields above are populated, so every visited field is a map.
m.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
mapv := v.Map()
var got []pref.MapKey
mapsort.Range(mapv, fd.MapKey().Kind(), func(key pref.MapKey, _ pref.Value) bool {
got = append(got, key)
return true
})
for wanti, key := range got {
// Decode the expected position from the value stored under this key.
var goti int
switch x := mapv.Get(key).Interface().(type) {
case bool:
if x {
goti = 1
}
case int32:
goti = int(x)
case uint64:
goti = int(x)
case string:
goti, _ = strconv.Atoi(x)
default:
t.Fatalf("unhandled map value type %T", x)
}
if wanti != goti {
t.Errorf("out of order range over map field %v: %v", fd.FullName(), got)
break
}
}
return true
})
}

View File

@ -20,7 +20,7 @@ import (
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
@ -64,25 +64,8 @@ func appendMessage(b []byte, m protoreflect.Message) []byte {
return b2
}
var fds []protoreflect.FieldDescriptor
m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
fds = append(fds, fd)
return true
})
sort.Slice(fds, func(i, j int) bool {
fdi, fdj := fds[i], fds[j]
switch {
case !fdi.IsExtension() && !fdj.IsExtension():
return fdi.Index() < fdj.Index()
case fdi.IsExtension() && fdj.IsExtension():
return fdi.FullName() < fdj.FullName()
default:
return !fdi.IsExtension() && fdj.IsExtension()
}
})
b = append(b, '{')
for _, fd := range fds {
order.RangeFields(m, order.IndexNameFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
k := string(fd.Name())
if fd.IsExtension() {
k = string("[" + fd.FullName() + "]")
@ -90,9 +73,10 @@ func appendMessage(b []byte, m protoreflect.Message) []byte {
b = append(b, k...)
b = append(b, ':')
b = appendValue(b, m.Get(fd), fd)
b = appendValue(b, v, fd)
b = append(b, delim()...)
}
return true
})
b = appendUnknown(b, m.GetUnknown())
b = bytes.TrimRight(b, delim())
b = append(b, '}')
@ -247,19 +231,14 @@ func appendList(b []byte, v protoreflect.List, fd protoreflect.FieldDescriptor)
}
func appendMap(b []byte, v protoreflect.Map, fd protoreflect.FieldDescriptor) []byte {
var ks []protoreflect.MapKey
mapsort.Range(v, fd.MapKey().Kind(), func(k protoreflect.MapKey, _ protoreflect.Value) bool {
ks = append(ks, k)
return true
})
b = append(b, '{')
for _, k := range ks {
order.RangeEntries(v, order.GenericKeyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool {
b = appendValue(b, k.Value(), fd.MapKey())
b = append(b, ':')
b = appendValue(b, v.Get(k), fd.MapValue())
b = appendValue(b, v, fd.MapValue())
b = append(b, delim()...)
}
return true
})
b = bytes.TrimRight(b, delim())
b = append(b, '}')
return b

89
internal/order/order.go Normal file
View File

@ -0,0 +1,89 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package order
import (
pref "google.golang.org/protobuf/reflect/protoreflect"
)
// FieldOrder specifies the ordering to visit message fields.
// It is a function that reports whether x is ordered before y.
// A nil FieldOrder (see AnyFieldOrder) requests no particular ordering.
type FieldOrder func(x, y pref.FieldDescriptor) bool
var (
	// AnyFieldOrder specifies no specific field ordering.
	AnyFieldOrder FieldOrder = nil

	// LegacyFieldOrder sorts fields in the same ordering as emitted by
	// wire serialization in the github.com/golang/protobuf implementation.
	//
	// Extensions come first, then fields that are not members of a
	// non-synthetic oneof, then members of non-synthetic oneofs grouped by
	// the declaration index of their oneof. Ties are broken by field number.
	LegacyFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
		ox, oy := x.ContainingOneof(), y.ContainingOneof()
		// inOneof reports membership in a real (non-synthetic) oneof.
		// Synthetic oneofs are treated exactly like "no oneof" throughout.
		inOneof := func(od pref.OneofDescriptor) bool {
			return od != nil && !od.IsSynthetic()
		}

		// Extension fields sort before non-extension fields.
		if x.IsExtension() != y.IsExtension() {
			return x.IsExtension() && !y.IsExtension()
		}
		// Fields not within a oneof sort before those within a oneof.
		if inOneof(ox) != inOneof(oy) {
			return !inOneof(ox) && inOneof(oy)
		}
		// Fields in disjoint oneof sets are sorted by declaration index.
		// This must only apply to non-synthetic oneofs: comparing members of
		// synthetic oneofs by oneof index while comparing them against plain
		// fields by number would not be a transitive ordering.
		if inOneof(ox) && inOneof(oy) && ox != oy {
			return ox.Index() < oy.Index()
		}
		// Fields sorted by field number.
		return x.Number() < y.Number()
	}

	// NumberFieldOrder sorts fields by their field number.
	NumberFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
		return x.Number() < y.Number()
	}

	// IndexNameFieldOrder sorts non-extension fields before extension fields.
	// Non-extensions are sorted according to their declaration index.
	// Extensions are sorted according to their full name.
	IndexNameFieldOrder FieldOrder = func(x, y pref.FieldDescriptor) bool {
		// Non-extension fields sort before extension fields.
		if x.IsExtension() != y.IsExtension() {
			return !x.IsExtension() && y.IsExtension()
		}
		// Extensions sorted by fullname.
		if x.IsExtension() && y.IsExtension() {
			return x.FullName() < y.FullName()
		}
		// Non-extensions sorted by declaration index.
		return x.Index() < y.Index()
	}
)
// KeyOrder specifies the ordering to visit map entries.
// It is a function that reports whether x is ordered before y.
// A nil KeyOrder (see AnyKeyOrder) requests no particular ordering.
type KeyOrder func(x, y pref.MapKey) bool
var (
	// AnyKeyOrder leaves the visiting order of map entries unspecified.
	AnyKeyOrder KeyOrder = nil

	// GenericKeyOrder places false ahead of true, orders integer keys by
	// ascending numeric value, and orders string keys lexicographically
	// according to UTF-8 codepoints. The comparison is chosen from the
	// dynamic Go type of the left-hand key; any other key type panics.
	GenericKeyOrder KeyOrder = func(lhs, rhs pref.MapKey) bool {
		switch lhs.Interface().(type) {
		case string:
			return lhs.String() < rhs.String()
		case uint32, uint64:
			return lhs.Uint() < rhs.Uint()
		case int32, int64:
			return lhs.Int() < rhs.Int()
		case bool:
			// Only "false before true" counts as ordered.
			if lhs.Bool() {
				return false
			}
			return rhs.Bool()
		}
		panic("invalid map key type")
	}
)

View File

@ -0,0 +1,175 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package order
import (
"math/rand"
"sort"
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
// fieldDesc is a minimal protoreflect.FieldDescriptor stand-in for the
// ordering tests. Only the methods defined on fieldDesc itself are backed by
// data; the tests leave the embedded FieldDescriptor nil, so it exists only to
// satisfy the interface.
type fieldDesc struct {
index int
name protoreflect.FullName
number protoreflect.FieldNumber
extension bool
oneofIndex int // non-zero means within oneof; negative means synthetic
pref.FieldDescriptor
}
// The accessors below report the corresponding struct fields verbatim.
func (d fieldDesc) Index() int { return d.index }
func (d fieldDesc) Name() pref.Name { return d.name.Name() }
func (d fieldDesc) FullName() pref.FullName { return d.name }
func (d fieldDesc) Number() pref.FieldNumber { return d.number }
func (d fieldDesc) IsExtension() bool { return d.extension }
// ContainingOneof decodes oneofIndex into a oneof descriptor: zero means the
// field is not in a oneof (nil is returned), a negative value denotes a
// synthetic oneof, and a positive value a regular one. The descriptor's index
// is the absolute value of oneofIndex.
func (d fieldDesc) ContainingOneof() pref.OneofDescriptor {
	if d.oneofIndex == 0 {
		return nil
	}
	if d.oneofIndex < 0 {
		return oneofDesc{index: -d.oneofIndex, synthetic: true}
	}
	return oneofDesc{index: d.oneofIndex, synthetic: false}
}
// oneofDesc is a minimal protoreflect.OneofDescriptor stand-in produced by
// fieldDesc.ContainingOneof. Only Index and IsSynthetic are backed by data;
// the embedded OneofDescriptor is left unset by fieldDesc.ContainingOneof.
type oneofDesc struct {
index int
synthetic bool
pref.OneofDescriptor
}
// Index reports the oneof's declaration index.
func (d oneofDesc) Index() int { return d.index }
// IsSynthetic reports whether the oneof is marked as synthetic.
func (d oneofDesc) IsSynthetic() bool { return d.synthetic }
// TestFieldOrder verifies each predefined FieldOrder.
//
// Every test case lists its fields already in the expected order. The test
// copies the list, permutes the copy, sorts it with the order under test, and
// requires the result to equal the original list.
func TestFieldOrder(t *testing.T) {
tests := []struct {
label string
order FieldOrder
fields []fieldDesc
}{{
label: "LegacyFieldOrder",
order: LegacyFieldOrder,
fields: []fieldDesc{
// Extension fields sorted first by field number.
{number: 2, extension: true},
{number: 4, extension: true},
{number: 100, extension: true},
{number: 120, extension: true},
// Non-extension fields that are not within a oneof
// sorted next by field number.
{number: 1},
{number: 5, oneofIndex: -9}, // synthetic oneof
{number: 10},
{number: 11, oneofIndex: -10}, // synthetic oneof
{number: 12},
// Non-synthetic oneofs sorted last by index.
{number: 13, oneofIndex: 4},
{number: 3, oneofIndex: 5},
{number: 9, oneofIndex: 5},
{number: 7, oneofIndex: 8},
},
}, {
label: "NumberFieldOrder",
order: NumberFieldOrder,
fields: []fieldDesc{
{number: 1, index: 5, name: "c"},
{number: 2, index: 2, name: "b"},
{number: 3, index: 3, name: "d"},
{number: 5, index: 1, name: "a"},
{number: 7, index: 7, name: "e"},
},
}, {
label: "IndexNameFieldOrder",
order: IndexNameFieldOrder,
fields: []fieldDesc{
// Non-extension fields sorted first by index.
{index: 0, number: 5, name: "c"},
{index: 2, number: 2, name: "a"},
{index: 4, number: 4, name: "b"},
{index: 7, number: 6, name: "d"},
// Extension fields sorted last by full name.
{index: 3, number: 1, name: "d.a", extension: true},
{index: 5, number: 3, name: "e", extension: true},
{index: 1, number: 7, name: "g", extension: true},
},
}}
for _, tt := range tests {
t.Run(tt.label, func(t *testing.T) {
want := tt.fields
// Work on a copy so that permuting and sorting leave want untouched.
got := append([]fieldDesc(nil), want...)
for i, j := range rand.Perm(len(got)) {
got[i], got[j] = got[j], got[i]
}
sort.Slice(got, func(i, j int) bool {
return tt.order(got[i], got[j])
})
// fieldDesc has unexported fields, so cmp needs an explicit comparer.
if diff := cmp.Diff(want, got,
cmp.Comparer(func(x, y fieldDesc) bool { return x == y }),
); diff != "" {
t.Errorf("order mismatch (-want +got):\n%s", diff)
}
})
}
}
// TestKeyOrder verifies GenericKeyOrder for each supported key type
// (bool, int32, int64, uint32, uint64, and string).
//
// Every test case lists its keys already in the expected order. The test
// converts them to map keys, permutes a copy, sorts it with the order under
// test, and requires the result to equal the original list.
func TestKeyOrder(t *testing.T) {
tests := []struct {
label string
order KeyOrder
keys []interface{}
}{{
label: "GenericKeyOrder",
order: GenericKeyOrder,
keys: []interface{}{false, true},
}, {
label: "GenericKeyOrder",
order: GenericKeyOrder,
keys: []interface{}{int32(-100), int32(-99), int32(-10), int32(-9), int32(-1), int32(0), int32(+1), int32(+9), int32(+10), int32(+99), int32(+100)},
}, {
label: "GenericKeyOrder",
order: GenericKeyOrder,
keys: []interface{}{int64(-100), int64(-99), int64(-10), int64(-9), int64(-1), int64(0), int64(+1), int64(+9), int64(+10), int64(+99), int64(+100)},
}, {
label: "GenericKeyOrder",
order: GenericKeyOrder,
keys: []interface{}{uint32(0), uint32(1), uint32(9), uint32(10), uint32(99), uint32(100)},
}, {
label: "GenericKeyOrder",
order: GenericKeyOrder,
keys: []interface{}{uint64(0), uint64(1), uint64(9), uint64(10), uint64(99), uint64(100)},
}, {
label: "GenericKeyOrder",
order: GenericKeyOrder,
keys: []interface{}{"", "a", "aa", "ab", "ba", "bb", "\u0080", "\u0080\u0081", "\u0082\u0080"},
}}
for _, tt := range tests {
t.Run(tt.label, func(t *testing.T) {
var got, want []protoreflect.MapKey
for _, v := range tt.keys {
want = append(want, pref.ValueOf(v).MapKey())
}
// got starts as a copy of want and is then permuted and re-sorted.
got = append(got, want...)
for i, j := range rand.Perm(len(got)) {
got[i], got[j] = got[j], got[i]
}
sort.Slice(got, func(i, j int) bool {
return tt.order(got[i], got[j])
})
// Compare the keys through their underlying Go values.
if diff := cmp.Diff(want, got, cmp.Transformer("", protoreflect.MapKey.Interface)); diff != "" {
t.Errorf("order mismatch (-want +got):\n%s", diff)
}
})
}
}

115
internal/order/range.go Normal file
View File

@ -0,0 +1,115 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package order provides ordered access to messages and maps.
package order
import (
"sort"
"sync"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
// messageField pairs a field descriptor with the value reported for it, as
// collected by RangeFields before sorting.
type messageField struct {
fd pref.FieldDescriptor
v pref.Value
}
// messageFieldPool recycles the scratch slices that RangeFields uses to
// collect fields. Entries are pointers to slices so that the slice header
// can be updated in place before the buffer is returned to the pool.
var messageFieldPool = sync.Pool{
New: func() interface{} { return new([]messageField) },
}
type (
// FieldRanger is an interface for visiting all fields in a message.
// The protoreflect.Message type implements this interface.
FieldRanger interface{ Range(VisitField) }
// VisitField is called every time a message field is visited.
// RangeFields stops iterating as soon as it returns false.
VisitField = func(pref.FieldDescriptor, pref.Value) bool
)
// RangeFields iterates over the fields of fs according to the specified order.
//
// If less is nil, fs.Range is called directly with fn and the fields are
// visited in whatever order fs produces. Otherwise every field is first
// collected, the collection is sorted with less, and fn is then called for
// each field in sorted order until it returns false.
func RangeFields(fs FieldRanger, less FieldOrder, fn VisitField) {
	if less == nil {
		fs.Range(fn)
		return
	}

	// Obtain a pre-allocated scratch buffer.
	p := messageFieldPool.Get().(*[]messageField)
	fields := (*p)[:0]
	defer func() {
		// Buffers that grew large are dropped instead of pooled so that a
		// single huge message does not keep a big allocation alive.
		if cap(fields) < 1024 {
			// Zero the used entries before pooling the buffer. Otherwise the
			// backing array would keep the descriptors and values of this
			// message reachable for as long as the buffer stays in the pool.
			for i := range fields {
				fields[i] = messageField{}
			}
			*p = fields[:0]
			messageFieldPool.Put(p)
		}
	}()

	// Collect all fields in the message and sort them.
	fs.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
		fields = append(fields, messageField{fd, v})
		return true
	})
	sort.Slice(fields, func(i, j int) bool {
		return less(fields[i].fd, fields[j].fd)
	})

	// Visit the fields in the specified ordering.
	for _, f := range fields {
		if !fn(f.fd, f.v) {
			return
		}
	}
}
// mapEntry pairs a map key with its value, as collected by RangeEntries
// before sorting.
type mapEntry struct {
k pref.MapKey
v pref.Value
}
// mapEntryPool recycles the scratch slices that RangeEntries uses to collect
// map entries. Entries are pointers to slices so that the slice header can be
// updated in place before the buffer is returned to the pool.
var mapEntryPool = sync.Pool{
New: func() interface{} { return new([]mapEntry) },
}
type (
// EntryRanger is an interface for visiting all entries in a map.
// The protoreflect.Map type implements this interface.
EntryRanger interface{ Range(VisitEntry) }
// VisitEntry is called every time a map entry is visited.
// RangeEntries stops iterating as soon as it returns false.
VisitEntry = func(pref.MapKey, pref.Value) bool
)
// RangeEntries iterates over the entries of es according to the specified order.
//
// If less is nil, es.Range is called directly with fn and the entries are
// visited in whatever order es produces. Otherwise every entry is first
// collected, the collection is sorted by key with less, and fn is then called
// for each entry in sorted order until it returns false.
func RangeEntries(es EntryRanger, less KeyOrder, fn VisitEntry) {
	if less == nil {
		es.Range(fn)
		return
	}

	// Obtain a pre-allocated scratch buffer.
	p := mapEntryPool.Get().(*[]mapEntry)
	entries := (*p)[:0]
	defer func() {
		// Buffers that grew large are dropped instead of pooled so that a
		// single huge map does not keep a big allocation alive.
		if cap(entries) < 1024 {
			// Zero the used entries before pooling the buffer. Otherwise the
			// backing array would keep the keys and values of this map
			// reachable for as long as the buffer stays in the pool.
			for i := range entries {
				entries[i] = mapEntry{}
			}
			*p = entries[:0]
			mapEntryPool.Put(p)
		}
	}()

	// Collect all entries in the map and sort them.
	es.Range(func(k pref.MapKey, v pref.Value) bool {
		entries = append(entries, mapEntry{k, v})
		return true
	})
	sort.Slice(entries, func(i, j int) bool {
		return less(entries[i].k, entries[j].k)
	})

	// Visit the entries in the specified ordering.
	for _, e := range entries {
		if !fn(e.k, e.v) {
			return
		}
	}
}

View File

@ -5,12 +5,9 @@
package proto
import (
"sort"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/fieldsort"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
@ -211,14 +208,15 @@ func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]
if messageset.IsMessageSet(m.Descriptor()) {
return o.marshalMessageSet(b, m)
}
// There are many choices for what order we visit fields in. The default one here
// is chosen for reasonable efficiency and simplicity given the protoreflect API.
// It is not deterministic, since Message.Range does not return fields in any
// defined order.
//
// When using deterministic serialization, we sort the known fields.
fieldOrder := order.AnyFieldOrder
if o.Deterministic {
// TODO: This should use a more natural ordering like NumberFieldOrder,
// but doing so breaks golden tests that make invalid assumptions about
// the output stability of this implementation.
fieldOrder = order.LegacyFieldOrder
}
var err error
o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
b, err = o.marshalField(b, fd, v)
return err == nil
})
@ -229,27 +227,6 @@ func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]
return b, nil
}
// rangeFields visits fields in a defined order when deterministic serialization is enabled.
func (o MarshalOptions) rangeFields(m protoreflect.Message, f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
if !o.Deterministic {
m.Range(f)
return
}
var fds []protoreflect.FieldDescriptor
m.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
fds = append(fds, fd)
return true
})
sort.Slice(fds, func(a, b int) bool {
return fieldsort.Less(fds[a], fds[b])
})
for _, fd := range fds {
if !f(fd, m.Get(fd)) {
break
}
}
}
func (o MarshalOptions) marshalField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) {
switch {
case fd.IsList():
@ -292,8 +269,12 @@ func (o MarshalOptions) marshalList(b []byte, fd protoreflect.FieldDescriptor, l
func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) {
keyf := fd.MapKey()
valf := fd.MapValue()
keyOrder := order.AnyKeyOrder
if o.Deterministic {
keyOrder = order.GenericKeyOrder
}
var err error
o.rangeMap(mapv, keyf.Kind(), func(key protoreflect.MapKey, value protoreflect.Value) bool {
order.RangeEntries(mapv, keyOrder, func(key protoreflect.MapKey, value protoreflect.Value) bool {
b = protowire.AppendTag(b, fd.Number(), protowire.BytesType)
var pos int
b, pos = appendSpeculativeLength(b)
@ -312,14 +293,6 @@ func (o MarshalOptions) marshalMap(b []byte, fd protoreflect.FieldDescriptor, ma
return b, err
}
func (o MarshalOptions) rangeMap(mapv protoreflect.Map, kind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) {
if !o.Deterministic {
mapv.Range(f)
return
}
mapsort.Range(mapv, kind, f)
}
// When encoding length-prefixed fields, we speculatively set aside some number of bytes
// for the length, encode the data, and then encode the length (shifting the data if necessary
// to make room).

View File

@ -9,6 +9,7 @@ import (
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
@ -28,8 +29,12 @@ func (o MarshalOptions) marshalMessageSet(b []byte, m protoreflect.Message) ([]b
if !flags.ProtoLegacy {
return b, errors.New("no support for message_set_wire_format")
}
fieldOrder := order.AnyFieldOrder
if o.Deterministic {
fieldOrder = order.NumberFieldOrder
}
var err error
o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
b, err = o.marshalMessageSetField(b, fd, v)
return err == nil
})