mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2025-02-06 00:40:02 +00:00
proto: add generic Size
Change-Id: I4ed123f4a9747fb4aba392bc5b9608d294bacc4d Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/169697 Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
parent
42577eaa4d
commit
61e93c70a2
@ -42,6 +42,7 @@ func main() {
|
||||
writeSource("internal/prototype/protofile_list_gen.go", generateListTypes())
|
||||
writeSource("proto/decode_gen.go", generateProtoDecode())
|
||||
writeSource("proto/encode_gen.go", generateProtoEncode())
|
||||
writeSource("proto/size_gen.go", generateProtoSize())
|
||||
}
|
||||
|
||||
// chdirRoot changes the working directory to the repository root.
|
||||
|
@ -269,3 +269,30 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref
|
||||
return b, nil
|
||||
}
|
||||
`))
|
||||
|
||||
func generateProtoSize() string {
|
||||
return mustExecute(protoSizeTemplate, ProtoKinds)
|
||||
}
|
||||
|
||||
var protoSizeTemplate = template.Must(template.New("").Parse(`
|
||||
func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
|
||||
switch kind {
|
||||
{{- range .}}
|
||||
case {{.Expr}}:
|
||||
{{if (eq .Name "Message") -}}
|
||||
return wire.SizeBytes(sizeMessage(v.Message()))
|
||||
{{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
|
||||
return wire.Size{{.WireType}}()
|
||||
{{- else if (eq .WireType "Bytes") -}}
|
||||
return wire.Size{{.WireType}}(len({{.FromValue}}))
|
||||
{{- else if (eq .WireType "Group") -}}
|
||||
return wire.Size{{.WireType}}(num, sizeMessage(v.Message()))
|
||||
{{- else -}}
|
||||
return wire.Size{{.WireType}}({{.FromValue}})
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
`))
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
protoV1 "github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/v2/encoding/textpb"
|
||||
"github.com/golang/protobuf/v2/internal/encoding/pack"
|
||||
"github.com/golang/protobuf/v2/internal/scalar"
|
||||
"github.com/golang/protobuf/v2/proto"
|
||||
@ -32,7 +33,7 @@ func TestDecode(t *testing.T) {
|
||||
wire := append(([]byte)(nil), test.wire...)
|
||||
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
|
||||
if err := proto.Unmarshal(wire, got); err != nil {
|
||||
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
|
||||
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
|
||||
return
|
||||
}
|
||||
|
||||
@ -43,7 +44,7 @@ func TestDecode(t *testing.T) {
|
||||
}
|
||||
|
||||
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
|
||||
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
|
||||
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -901,3 +902,8 @@ func extend(desc *protoV1.ExtensionDesc, value interface{}) buildOpt {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func marshalText(m proto.Message) string {
|
||||
b, _ := textpb.Marshal(m)
|
||||
return string(b)
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
|
||||
protoV1 "github.com/golang/protobuf/proto"
|
||||
//_ "github.com/golang/protobuf/v2/internal/legacy"
|
||||
"github.com/golang/protobuf/v2/proto"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
@ -18,7 +17,12 @@ func TestEncode(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
|
||||
wire, err := proto.Marshal(want)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
|
||||
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
|
||||
}
|
||||
|
||||
size := proto.Size(want)
|
||||
if size != len(wire) {
|
||||
t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want))
|
||||
}
|
||||
|
||||
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
|
||||
@ -41,12 +45,12 @@ func TestEncodeDeterministic(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
|
||||
wire, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
|
||||
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
|
||||
}
|
||||
|
||||
wire2, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
|
||||
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
|
||||
}
|
||||
|
||||
if !bytes.Equal(wire, wire2) {
|
||||
@ -55,12 +59,12 @@ func TestEncodeDeterministic(t *testing.T) {
|
||||
|
||||
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
|
||||
if err := proto.Unmarshal(wire, got); err != nil {
|
||||
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
|
||||
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
|
||||
return
|
||||
}
|
||||
|
||||
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
|
||||
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
|
||||
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
79
proto/size.go
Normal file
79
proto/size.go
Normal file
@ -0,0 +1,79 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/v2/internal/encoding/wire"
|
||||
"github.com/golang/protobuf/v2/reflect/protoreflect"
|
||||
)
|
||||
|
||||
// Size returns the size in bytes of the wire-format encoding of m.
|
||||
func Size(m Message) int {
|
||||
return sizeMessage(m.ProtoReflect())
|
||||
}
|
||||
|
||||
func sizeMessage(m protoreflect.Message) (size int) {
|
||||
fields := m.Type().Fields()
|
||||
knownFields := m.KnownFields()
|
||||
m.KnownFields().Range(func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
|
||||
field := fields.ByNumber(num)
|
||||
if field == nil {
|
||||
field = knownFields.ExtensionTypes().ByNumber(num)
|
||||
if field == nil {
|
||||
panic(fmt.Errorf("no descriptor for field %d in %q", num, m.Type().FullName()))
|
||||
}
|
||||
}
|
||||
size += sizeField(field, value)
|
||||
return true
|
||||
})
|
||||
m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool {
|
||||
size += len(raw)
|
||||
return true
|
||||
})
|
||||
return size
|
||||
}
|
||||
|
||||
func sizeField(field protoreflect.FieldDescriptor, value protoreflect.Value) (size int) {
|
||||
num := field.Number()
|
||||
kind := field.Kind()
|
||||
switch {
|
||||
case field.Cardinality() != protoreflect.Repeated:
|
||||
return wire.SizeTag(num) + sizeSingular(num, kind, value)
|
||||
case field.IsMap():
|
||||
return sizeMap(num, kind, field.MessageType(), value.Map())
|
||||
case field.IsPacked():
|
||||
return sizePacked(num, kind, value.List())
|
||||
default:
|
||||
return sizeList(num, kind, value.List())
|
||||
}
|
||||
}
|
||||
|
||||
func sizeMap(num wire.Number, kind protoreflect.Kind, mdesc protoreflect.MessageDescriptor, mapv protoreflect.Map) (size int) {
|
||||
keyf := mdesc.Fields().ByNumber(1)
|
||||
valf := mdesc.Fields().ByNumber(2)
|
||||
mapv.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool {
|
||||
size += wire.SizeTag(num)
|
||||
size += wire.SizeBytes(sizeField(keyf, key.Value()) + sizeField(valf, value))
|
||||
return true
|
||||
})
|
||||
return size
|
||||
}
|
||||
|
||||
func sizePacked(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) {
|
||||
content := 0
|
||||
for i, llen := 0, list.Len(); i < llen; i++ {
|
||||
content += sizeSingular(num, kind, list.Get(i))
|
||||
}
|
||||
return wire.SizeTag(num) + wire.SizeBytes(content)
|
||||
}
|
||||
|
||||
func sizeList(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) {
|
||||
for i, llen := 0, list.Len(); i < llen; i++ {
|
||||
size += wire.SizeTag(num) + sizeSingular(num, kind, list.Get(i))
|
||||
}
|
||||
return size
|
||||
}
|
55
proto/size_gen.go
Normal file
55
proto/size_gen.go
Normal file
@ -0,0 +1,55 @@
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style.
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Code generated by generate-types. DO NOT EDIT.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"github.com/golang/protobuf/v2/internal/encoding/wire"
|
||||
"github.com/golang/protobuf/v2/reflect/protoreflect"
|
||||
)
|
||||
|
||||
func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
|
||||
switch kind {
|
||||
case protoreflect.BoolKind:
|
||||
return wire.SizeVarint(wire.EncodeBool(v.Bool()))
|
||||
case protoreflect.EnumKind:
|
||||
return wire.SizeVarint(uint64(v.Enum()))
|
||||
case protoreflect.Int32Kind:
|
||||
return wire.SizeVarint(uint64(int32(v.Int())))
|
||||
case protoreflect.Sint32Kind:
|
||||
return wire.SizeVarint(wire.EncodeZigZag(int64(int32(v.Int()))))
|
||||
case protoreflect.Uint32Kind:
|
||||
return wire.SizeVarint(uint64(uint32(v.Uint())))
|
||||
case protoreflect.Int64Kind:
|
||||
return wire.SizeVarint(uint64(v.Int()))
|
||||
case protoreflect.Sint64Kind:
|
||||
return wire.SizeVarint(wire.EncodeZigZag(v.Int()))
|
||||
case protoreflect.Uint64Kind:
|
||||
return wire.SizeVarint(v.Uint())
|
||||
case protoreflect.Sfixed32Kind:
|
||||
return wire.SizeFixed32()
|
||||
case protoreflect.Fixed32Kind:
|
||||
return wire.SizeFixed32()
|
||||
case protoreflect.FloatKind:
|
||||
return wire.SizeFixed32()
|
||||
case protoreflect.Sfixed64Kind:
|
||||
return wire.SizeFixed64()
|
||||
case protoreflect.Fixed64Kind:
|
||||
return wire.SizeFixed64()
|
||||
case protoreflect.DoubleKind:
|
||||
return wire.SizeFixed64()
|
||||
case protoreflect.StringKind:
|
||||
return wire.SizeBytes(len([]byte(v.String())))
|
||||
case protoreflect.BytesKind:
|
||||
return wire.SizeBytes(len(v.Bytes()))
|
||||
case protoreflect.MessageKind:
|
||||
return wire.SizeBytes(sizeMessage(v.Message()))
|
||||
case protoreflect.GroupKind:
|
||||
return wire.SizeGroup(num, sizeMessage(v.Message()))
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user