proto: return an error instead of producing invalid wire format

There currently is no risk of producing invalid wire format,
but that will change with subsequent changes regarding lazy decoding.

We have been running this change in production for about a month,
without ever triggering the check (until lazy decoding is involved).

related to golang/protobuf#1609

Change-Id: I3c5c956aee2fa81f99dea03ed2a977a1547081fc
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/579595
Auto-Submit: Michael Stapelberg <stapelberg@google.com>
Reviewed-by: Lasse Folger <lassefolger@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Michael Stapelberg 2024-04-17 16:21:00 +02:00 committed by Gopher Robot
parent 671c2db939
commit 94bb78c93b
5 changed files with 71 additions and 7 deletions

View File

@ -87,3 +87,18 @@ func InvalidUTF8(name string) error {
func RequiredNotSet(name string) error {
return New("required field %v not set", name)
}
type SizeMismatchError struct {
Calculated, Measured int
}
func (e *SizeMismatchError) Error() string {
return fmt.Sprintf("size mismatch (see https://github.com/golang/protobuf/issues/1609): calculated=%d, measured=%d", e.Calculated, e.Measured)
}
func MismatchedSizeCalculation(calculated, measured int) error {
return &SizeMismatchError{
Calculated: calculated,
Measured: measured,
}
}

View File

@ -233,9 +233,15 @@ func sizeMessageInfo(p pointer, f *coderFieldInfo, opts marshalOptions) int {
}
func appendMessageInfo(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
calculatedSize := f.mi.sizePointer(p.Elem(), opts)
b = protowire.AppendVarint(b, f.wiretag)
b = protowire.AppendVarint(b, uint64(f.mi.sizePointer(p.Elem(), opts)))
return f.mi.marshalAppendPointer(b, p.Elem(), opts)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := f.mi.marshalAppendPointer(b, p.Elem(), opts)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}
func consumeMessageInfo(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
@ -267,9 +273,15 @@ func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
}
func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
calculatedSize := proto.Size(m)
b = protowire.AppendVarint(b, wiretag)
b = protowire.AppendVarint(b, uint64(proto.Size(m)))
return opts.Options().MarshalAppend(b, m)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := opts.Options().MarshalAppend(b, m)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
return b, err
}
func consumeMessage(b []byte, m proto.Message, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
@ -482,10 +494,14 @@ func appendMessageSliceInfo(b []byte, p pointer, f *coderFieldInfo, opts marshal
b = protowire.AppendVarint(b, f.wiretag)
siz := f.mi.sizePointer(v, opts)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = f.mi.marshalAppendPointer(b, v, opts)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
@ -538,10 +554,14 @@ func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type
b = protowire.AppendVarint(b, wiretag)
siz := proto.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
b, err = opts.Options().MarshalAppend(b, m)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}
@ -599,11 +619,15 @@ func appendMessageSliceValue(b []byte, listv protoreflect.Value, wiretag uint64,
b = protowire.AppendVarint(b, wiretag)
siz := proto.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
var err error
b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
if measuredSize := len(b) - before; siz != measuredSize {
return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
}
}
return b, nil
}

View File

@ -9,6 +9,7 @@ import (
"sort"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/reflect/protoreflect"
)
@ -240,11 +241,16 @@ func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coder
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
size += mapi.valFuncs.size(val, mapValTagSize, opts)
b = protowire.AppendVarint(b, uint64(size))
before := len(b)
b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
if err != nil {
return nil, err
}
return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
if measuredSize := len(b) - before; size != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(size, measuredSize)
}
return b, err
} else {
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
val := pointerOfValue(valrv)
@ -259,7 +265,12 @@ func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coder
}
b = protowire.AppendVarint(b, mapi.valWiretag)
b = protowire.AppendVarint(b, uint64(valSize))
return f.mi.marshalAppendPointer(b, val, opts)
before := len(b)
b, err = f.mi.marshalAppendPointer(b, val, opts)
if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
}
return b, err
}
}

View File

@ -5,12 +5,17 @@
package proto
import (
"errors"
"fmt"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/order"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
protoerrors "google.golang.org/protobuf/internal/errors"
)
// MarshalOptions configures the marshaler.
@ -175,6 +180,10 @@ func (o MarshalOptions) marshal(b []byte, m protoreflect.Message) (out protoifac
out.Buf, err = o.marshalMessageSlow(b, m)
}
if err != nil {
var mismatch *protoerrors.SizeMismatchError
if errors.As(err, &mismatch) {
return out, fmt.Errorf("marshaling %s: %v", string(m.Descriptor().FullName()), err)
}
return out, err
}
if allowPartial {

View File

@ -47,11 +47,16 @@ func (o MarshalOptions) marshalMessageSet(b []byte, m protoreflect.Message) ([]b
func (o MarshalOptions) marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) {
b = messageset.AppendFieldStart(b, fd.Number())
b = protowire.AppendTag(b, messageset.FieldMessage, protowire.BytesType)
b = protowire.AppendVarint(b, uint64(o.Size(value.Message().Interface())))
calculatedSize := o.Size(value.Message().Interface())
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
b, err := o.marshalMessage(b, value.Message())
if err != nil {
return b, err
}
if measuredSize := len(b) - before; calculatedSize != measuredSize {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
b = messageset.AppendFieldEnd(b)
return b, nil
}