diff --git a/internal/cmd/generate-types/main.go b/internal/cmd/generate-types/main.go index ac727e89..2f738722 100644 --- a/internal/cmd/generate-types/main.go +++ b/internal/cmd/generate-types/main.go @@ -42,6 +42,7 @@ func main() { writeSource("internal/prototype/protofile_list_gen.go", generateListTypes()) writeSource("proto/decode_gen.go", generateProtoDecode()) writeSource("proto/encode_gen.go", generateProtoEncode()) + writeSource("proto/size_gen.go", generateProtoSize()) } // chdirRoot changes the working directory to the repository root. diff --git a/internal/cmd/generate-types/proto.go b/internal/cmd/generate-types/proto.go index dea9581e..7d688cd1 100644 --- a/internal/cmd/generate-types/proto.go +++ b/internal/cmd/generate-types/proto.go @@ -269,3 +269,30 @@ func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoref return b, nil } `)) + +func generateProtoSize() string { + return mustExecute(protoSizeTemplate, ProtoKinds) +} + +var protoSizeTemplate = template.Must(template.New("").Parse(` +func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int { + switch kind { + {{- range .}} + case {{.Expr}}: + {{if (eq .Name "Message") -}} + return wire.SizeBytes(sizeMessage(v.Message())) + {{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}} + return wire.Size{{.WireType}}() + {{- else if (eq .WireType "Bytes") -}} + return wire.Size{{.WireType}}(len({{.FromValue}})) + {{- else if (eq .WireType "Group") -}} + return wire.Size{{.WireType}}(num, sizeMessage(v.Message())) + {{- else -}} + return wire.Size{{.WireType}}({{.FromValue}}) + {{- end}} + {{- end}} + default: + return 0 + } +} +`)) diff --git a/proto/decode_test.go b/proto/decode_test.go index 8dc218b3..feb4ac64 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -10,6 +10,7 @@ import ( "testing" protoV1 "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/v2/encoding/textpb" "github.com/golang/protobuf/v2/internal/encoding/pack" "github.com/golang/protobuf/v2/internal/scalar" "github.com/golang/protobuf/v2/proto" @@ -32,7 +33,7 @@ func TestDecode(t *testing.T) { wire := append(([]byte)(nil), test.wire...) got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message) if err := proto.Unmarshal(wire, got); err != nil { - t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message))) + t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want)) return } @@ -43,7 +44,7 @@ func TestDecode(t *testing.T) { } if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) { - t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message))) + t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want)) } }) } @@ -901,3 +902,8 @@ func extend(desc *protoV1.ExtensionDesc, value interface{}) buildOpt { } } } + +func marshalText(m proto.Message) string { + b, _ := textpb.Marshal(m) + return string(b) +} diff --git a/proto/encode_test.go b/proto/encode_test.go index 4c5034d0..d467b74d 100644 --- a/proto/encode_test.go +++ b/proto/encode_test.go @@ -7,7 +7,6 @@ import ( "testing" protoV1 "github.com/golang/protobuf/proto" - //_ "github.com/golang/protobuf/v2/internal/legacy" "github.com/golang/protobuf/v2/proto" "github.com/google/go-cmp/cmp" ) @@ -18,7 +17,12 @@ func TestEncode(t *testing.T) { t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) { wire, err := proto.Marshal(want) if err != nil { - t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message))) + t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want)) + } + + size := proto.Size(want) + if size != len(wire) { + t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want)) } got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message) @@ -41,12 +45,12 @@ func TestEncodeDeterministic(t *testing.T) { t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) { wire, err := proto.MarshalOptions{Deterministic: true}.Marshal(want) if err != nil { - t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message))) + t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want)) } wire2, err := proto.MarshalOptions{Deterministic: true}.Marshal(want) if err != nil { - t.Fatalf("Marshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message))) + t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want)) } if !bytes.Equal(wire, wire2) { @@ -55,12 +59,12 @@ func TestEncodeDeterministic(t *testing.T) { got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message) if err := proto.Unmarshal(wire, got); err != nil { - t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message))) + t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want)) return } if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) { - t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message))) + t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want)) } }) } diff --git a/proto/size.go b/proto/size.go new file mode 100644 index 00000000..8c9263fb --- /dev/null +++ b/proto/size.go @@ -0,0 +1,79 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proto + +import ( + "fmt" + + "github.com/golang/protobuf/v2/internal/encoding/wire" + "github.com/golang/protobuf/v2/reflect/protoreflect" +) + +// Size returns the size in bytes of the wire-format encoding of m. +func Size(m Message) int { + return sizeMessage(m.ProtoReflect()) +} + +func sizeMessage(m protoreflect.Message) (size int) { + fields := m.Type().Fields() + knownFields := m.KnownFields() + m.KnownFields().Range(func(num protoreflect.FieldNumber, value protoreflect.Value) bool { + field := fields.ByNumber(num) + if field == nil { + field = knownFields.ExtensionTypes().ByNumber(num) + if field == nil { + panic(fmt.Errorf("no descriptor for field %d in %q", num, m.Type().FullName())) + } + } + size += sizeField(field, value) + return true + }) + m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool { + size += len(raw) + return true + }) + return size +} + +func sizeField(field protoreflect.FieldDescriptor, value protoreflect.Value) (size int) { + num := field.Number() + kind := field.Kind() + switch { + case field.Cardinality() != protoreflect.Repeated: + return wire.SizeTag(num) + sizeSingular(num, kind, value) + case field.IsMap(): + return sizeMap(num, kind, field.MessageType(), value.Map()) + case field.IsPacked(): + return sizePacked(num, kind, value.List()) + default: + return sizeList(num, kind, value.List()) + } +} + +func sizeMap(num wire.Number, kind protoreflect.Kind, mdesc protoreflect.MessageDescriptor, mapv protoreflect.Map) (size int) { + keyf := mdesc.Fields().ByNumber(1) + valf := mdesc.Fields().ByNumber(2) + mapv.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool { + size += wire.SizeTag(num) + size += wire.SizeBytes(sizeField(keyf, key.Value()) + sizeField(valf, value)) + return true + }) + return size +} + +func sizePacked(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) { + content := 0 + for i, llen := 0, list.Len(); i < llen; i++ { + content += sizeSingular(num, kind, list.Get(i)) + } + return wire.SizeTag(num) + wire.SizeBytes(content) +} + +func sizeList(num wire.Number, kind protoreflect.Kind, list protoreflect.List) (size int) { + for i, llen := 0, list.Len(); i < llen; i++ { + size += wire.SizeTag(num) + sizeSingular(num, kind, list.Get(i)) + } + return size +} diff --git a/proto/size_gen.go b/proto/size_gen.go new file mode 100644 index 00000000..d71c7c74 --- /dev/null +++ b/proto/size_gen.go @@ -0,0 +1,55 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style. +// license that can be found in the LICENSE file. + +// Code generated by generate-types. DO NOT EDIT. + +package proto + +import ( + "github.com/golang/protobuf/v2/internal/encoding/wire" + "github.com/golang/protobuf/v2/reflect/protoreflect" +) + +func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int { + switch kind { + case protoreflect.BoolKind: + return wire.SizeVarint(wire.EncodeBool(v.Bool())) + case protoreflect.EnumKind: + return wire.SizeVarint(uint64(v.Enum())) + case protoreflect.Int32Kind: + return wire.SizeVarint(uint64(int32(v.Int()))) + case protoreflect.Sint32Kind: + return wire.SizeVarint(wire.EncodeZigZag(int64(int32(v.Int())))) + case protoreflect.Uint32Kind: + return wire.SizeVarint(uint64(uint32(v.Uint()))) + case protoreflect.Int64Kind: + return wire.SizeVarint(uint64(v.Int())) + case protoreflect.Sint64Kind: + return wire.SizeVarint(wire.EncodeZigZag(v.Int())) + case protoreflect.Uint64Kind: + return wire.SizeVarint(v.Uint()) + case protoreflect.Sfixed32Kind: + return wire.SizeFixed32() + case protoreflect.Fixed32Kind: + return wire.SizeFixed32() + case protoreflect.FloatKind: + return wire.SizeFixed32() + case protoreflect.Sfixed64Kind: + return wire.SizeFixed64() + case protoreflect.Fixed64Kind: + return wire.SizeFixed64() + case protoreflect.DoubleKind: + return wire.SizeFixed64() + case protoreflect.StringKind: + return wire.SizeBytes(len([]byte(v.String()))) + case protoreflect.BytesKind: + return wire.SizeBytes(len(v.Bytes())) + case protoreflect.MessageKind: + return wire.SizeBytes(sizeMessage(v.Message())) + case protoreflect.GroupKind: + return wire.SizeGroup(num, sizeMessage(v.Message())) + default: + return 0 + } +}