proto: add generic Size

Change-Id: I4ed123f4a9747fb4aba392bc5b9608d294bacc4d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/169697
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Damien Neil 2019-03-27 09:23:20 -07:00
parent 42577eaa4d
commit 61e93c70a2
6 changed files with 180 additions and 8 deletions

View File

@ -42,6 +42,7 @@ func main() {
writeSource("internal/prototype/protofile_list_gen.go", generateListTypes())
writeSource("proto/decode_gen.go", generateProtoDecode())
writeSource("proto/encode_gen.go", generateProtoEncode())
writeSource("proto/size_gen.go", generateProtoSize())
}
// chdirRoot changes the working directory to the repository root.

View File

@ -269,3 +269,30 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref
return b, nil
}
`))
func generateProtoSize() string {
return mustExecute(protoSizeTemplate, ProtoKinds)
}
var protoSizeTemplate = template.Must(template.New("").Parse(`
func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
switch kind {
{{- range .}}
case {{.Expr}}:
{{if (eq .Name "Message") -}}
return wire.SizeBytes(sizeMessage(v.Message()))
{{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
return wire.Size{{.WireType}}()
{{- else if (eq .WireType "Bytes") -}}
return wire.Size{{.WireType}}(len({{.FromValue}}))
{{- else if (eq .WireType "Group") -}}
return wire.Size{{.WireType}}(num, sizeMessage(v.Message()))
{{- else -}}
return wire.Size{{.WireType}}({{.FromValue}})
{{- end}}
{{- end}}
default:
return 0
}
}
`))

View File

@ -10,6 +10,7 @@ import (
"testing"
protoV1 "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/v2/encoding/textpb"
"github.com/golang/protobuf/v2/internal/encoding/pack"
"github.com/golang/protobuf/v2/internal/scalar"
"github.com/golang/protobuf/v2/proto"
@ -32,7 +33,7 @@ func TestDecode(t *testing.T) {
wire := append(([]byte)(nil), test.wire...)
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
if err := proto.Unmarshal(wire, got); err != nil {
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
return
}
@ -43,7 +44,7 @@ func TestDecode(t *testing.T) {
}
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
}
})
}
@ -901,3 +902,8 @@ func extend(desc *protoV1.ExtensionDesc, value interface{}) buildOpt {
}
}
}
func marshalText(m proto.Message) string {
b, _ := textpb.Marshal(m)
return string(b)
}

View File

@ -7,7 +7,6 @@ import (
"testing"
protoV1 "github.com/golang/protobuf/proto"
//_ "github.com/golang/protobuf/v2/internal/legacy"
"github.com/golang/protobuf/v2/proto"
"github.com/google/go-cmp/cmp"
)
@ -18,7 +17,12 @@ func TestEncode(t *testing.T) {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
wire, err := proto.Marshal(want)
if err != nil {
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
}
size := proto.Size(want)
if size != len(wire) {
t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want))
}
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
@ -41,12 +45,12 @@ func TestEncodeDeterministic(t *testing.T) {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
wire, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
if err != nil {
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
}
wire2, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
if err != nil {
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
}
if !bytes.Equal(wire, wire2) {
@ -55,12 +59,12 @@ func TestEncodeDeterministic(t *testing.T) {
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
if err := proto.Unmarshal(wire, got); err != nil {
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
return
}
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
}
})
}

79
proto/size.go Normal file
View File

@ -0,0 +1,79 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proto
import (
"fmt"
"github.com/golang/protobuf/v2/internal/encoding/wire"
"github.com/golang/protobuf/v2/reflect/protoreflect"
)
// Size returns the size in bytes of the wire-format encoding of m.
func Size(m Message) int {
return sizeMessage(m.ProtoReflect())
}
func sizeMessage(m protoreflect.Message) (size int) {
fields := m.Type().Fields()
knownFields := m.KnownFields()
m.KnownFields().Range(func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
field := fields.ByNumber(num)
if field == nil {
field = knownFields.ExtensionTypes().ByNumber(num)
if field == nil {
panic(fmt.Errorf("no descriptor for field %d in %q", num, m.Type().FullName()))
}
}
size += sizeField(field, value)
return true
})
m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool {
size += len(raw)
return true
})
return size
}
func sizeField(field protoreflect.FieldDescriptor, value protoreflect.Value) (size int) {
num := field.Number()
kind := field.Kind()
switch {
case field.Cardinality() != protoreflect.Repeated:
return wire.SizeTag(num) + sizeSingular(num, kind, value)
case field.IsMap():
return sizeMap(num, kind, field.MessageType(), value.Map())
case field.IsPacked():
return sizePacked(num, kind, value.List())
default:
return sizeList(num, kind, value.List())
}
}
func sizeMap(num wire.Number, kind protoreflect.Kind, mdesc protoreflect.MessageDescriptor, mapv protoreflect.Map) (size int) {
keyf := mdesc.Fields().ByNumber(1)
valf := mdesc.Fields().ByNumber(2)
mapv.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool {
size += wire.SizeTag(num)
size += wire.SizeBytes(sizeField(keyf, key.Value()) + sizeField(valf, value))
return true
})
return size
}
func sizePacked(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) {
content := 0
for i, llen := 0, list.Len(); i < llen; i++ {
content += sizeSingular(num, kind, list.Get(i))
}
return wire.SizeTag(num) + wire.SizeBytes(content)
}
func sizeList(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) {
for i, llen := 0, list.Len(); i < llen; i++ {
size += wire.SizeTag(num) + sizeSingular(num, kind, list.Get(i))
}
return size
}

55
proto/size_gen.go Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style.
// license that can be found in the LICENSE file.
// Code generated by generate-types. DO NOT EDIT.
package proto
import (
"github.com/golang/protobuf/v2/internal/encoding/wire"
"github.com/golang/protobuf/v2/reflect/protoreflect"
)
func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
switch kind {
case protoreflect.BoolKind:
return wire.SizeVarint(wire.EncodeBool(v.Bool()))
case protoreflect.EnumKind:
return wire.SizeVarint(uint64(v.Enum()))
case protoreflect.Int32Kind:
return wire.SizeVarint(uint64(int32(v.Int())))
case protoreflect.Sint32Kind:
return wire.SizeVarint(wire.EncodeZigZag(int64(int32(v.Int()))))
case protoreflect.Uint32Kind:
return wire.SizeVarint(uint64(uint32(v.Uint())))
case protoreflect.Int64Kind:
return wire.SizeVarint(uint64(v.Int()))
case protoreflect.Sint64Kind:
return wire.SizeVarint(wire.EncodeZigZag(v.Int()))
case protoreflect.Uint64Kind:
return wire.SizeVarint(v.Uint())
case protoreflect.Sfixed32Kind:
return wire.SizeFixed32()
case protoreflect.Fixed32Kind:
return wire.SizeFixed32()
case protoreflect.FloatKind:
return wire.SizeFixed32()
case protoreflect.Sfixed64Kind:
return wire.SizeFixed64()
case protoreflect.Fixed64Kind:
return wire.SizeFixed64()
case protoreflect.DoubleKind:
return wire.SizeFixed64()
case protoreflect.StringKind:
return wire.SizeBytes(len([]byte(v.String())))
case protoreflect.BytesKind:
return wire.SizeBytes(len(v.Bytes()))
case protoreflect.MessageKind:
return wire.SizeBytes(sizeMessage(v.Message()))
case protoreflect.GroupKind:
return wire.SizeGroup(num, sizeMessage(v.Message()))
default:
return 0
}
}