mirror of
https://github.com/protocolbuffers/protobuf-go.git
synced 2024-12-27 15:26:51 +00:00
d30e561d9e
Add functions to the proto package which plumb through the fast-path state. As a sample use case: A followup CL adds an Initialized field to protoiface.UnmarshalOutput, permitting the unmarshaller to report back when it can confirm that a message is fully initialized. We want to preserve that information when an unmarshal operation threads through the proto package (such as when unmarshaling extensions). To allow these functions to be added as methods of MarshalOptions and UnmarshalOptions rather than top-level functions, separate the options from the input structs. Also update options passed to fast-path methods to set AllowPartial and Merge to reflect the expected behavior of those methods. (Always allow partial, never merge.) Change-Id: I482477b0c9340793be533e75a86d0bb88708716a Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215877 Reviewed-by: Joe Tsai <joetsai@google.com>
266 lines
7.4 KiB
Go
266 lines
7.4 KiB
Go
// Copyright 2018 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package proto
|
|
|
|
import (
|
|
"google.golang.org/protobuf/internal/encoding/messageset"
|
|
"google.golang.org/protobuf/internal/encoding/wire"
|
|
"google.golang.org/protobuf/internal/errors"
|
|
"google.golang.org/protobuf/internal/flags"
|
|
"google.golang.org/protobuf/internal/pragma"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/reflect/protoregistry"
|
|
"google.golang.org/protobuf/runtime/protoiface"
|
|
)
|
|
|
|
// UnmarshalOptions configures the unmarshaler.
//
// Example usage:
//   err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
//
// NOTE: the field list must stay identical (names, types, order) to
// protoiface.UnmarshalOptions; see the conversion check below.
type UnmarshalOptions struct {
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
}

// Compile-time check that UnmarshalOptions is convertible to
// protoiface.UnmarshalOptions. A struct conversion only compiles when both
// types have identical field sequences. The unmarshal method relies on this
// conversion to pass its options to fast-path Unmarshal methods.
var _ = protoiface.UnmarshalOptions(UnmarshalOptions{})
|
|
|
|
// Unmarshal parses the wire-format message in b and places the result in m.
|
|
func Unmarshal(b []byte, m Message) error {
|
|
_, err := UnmarshalOptions{}.unmarshal(b, m)
|
|
return err
|
|
}
|
|
|
|
// Unmarshal parses the wire-format message in b and places the result in m.
|
|
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
|
|
_, err := o.unmarshal(b, m)
|
|
return err
|
|
}
|
|
|
|
// UnmarshalState parses a wire-format message and places the result in m.
|
|
//
|
|
// This method permits fine-grained control over the unmarshaler.
|
|
// Most users should use Unmarshal instead.
|
|
func (o UnmarshalOptions) UnmarshalState(m Message, in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
|
|
return o.unmarshal(in.Buf, m)
|
|
}
|
|
|
|
func (o UnmarshalOptions) unmarshal(b []byte, message Message) (out protoiface.UnmarshalOutput, err error) {
|
|
if o.Resolver == nil {
|
|
o.Resolver = protoregistry.GlobalTypes
|
|
}
|
|
if !o.Merge {
|
|
Reset(message)
|
|
}
|
|
allowPartial := o.AllowPartial
|
|
o.Merge = true
|
|
o.AllowPartial = true
|
|
m := message.ProtoReflect()
|
|
methods := protoMethods(m)
|
|
if methods != nil && methods.Unmarshal != nil &&
|
|
!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
|
|
out, err = methods.Unmarshal(m, protoiface.UnmarshalInput{
|
|
Buf: b,
|
|
}, protoiface.UnmarshalOptions(o))
|
|
} else {
|
|
err = o.unmarshalMessageSlow(b, m)
|
|
}
|
|
if err != nil {
|
|
return out, err
|
|
}
|
|
if allowPartial {
|
|
return out, nil
|
|
}
|
|
return out, isInitialized(m)
|
|
}
|
|
|
|
func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
|
|
_, err := o.unmarshal(b, m.Interface())
|
|
return err
|
|
}
|
|
|
|
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
|
|
md := m.Descriptor()
|
|
if messageset.IsMessageSet(md) {
|
|
return unmarshalMessageSet(b, m, o)
|
|
}
|
|
fields := md.Fields()
|
|
for len(b) > 0 {
|
|
// Parse the tag (field number and wire type).
|
|
num, wtyp, tagLen := wire.ConsumeTag(b)
|
|
if tagLen < 0 {
|
|
return wire.ParseError(tagLen)
|
|
}
|
|
if num > wire.MaxValidNumber {
|
|
return errors.New("invalid field number")
|
|
}
|
|
|
|
// Find the field descriptor for this field number.
|
|
fd := fields.ByNumber(num)
|
|
if fd == nil && md.ExtensionRanges().Has(num) {
|
|
extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
|
|
if err != nil && err != protoregistry.NotFound {
|
|
return err
|
|
}
|
|
if extType != nil {
|
|
fd = extType.TypeDescriptor()
|
|
}
|
|
}
|
|
var err error
|
|
if fd == nil {
|
|
err = errUnknown
|
|
} else if flags.ProtoLegacy {
|
|
if fd.IsWeak() && fd.Message().IsPlaceholder() {
|
|
err = errUnknown // weak referent is not linked in
|
|
}
|
|
}
|
|
|
|
// Parse the field value.
|
|
var valLen int
|
|
switch {
|
|
case err != nil:
|
|
case fd.IsList():
|
|
valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
|
|
case fd.IsMap():
|
|
valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
|
|
default:
|
|
valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
|
|
}
|
|
if err != nil {
|
|
if err != errUnknown {
|
|
return err
|
|
}
|
|
valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
|
|
if valLen < 0 {
|
|
return wire.ParseError(valLen)
|
|
}
|
|
m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
|
|
}
|
|
b = b[tagLen+valLen:]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp wire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
|
|
v, n, err := o.unmarshalScalar(b, wtyp, fd)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
switch fd.Kind() {
|
|
case protoreflect.GroupKind, protoreflect.MessageKind:
|
|
m2 := m.Mutable(fd).Message()
|
|
if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
|
|
return n, err
|
|
}
|
|
default:
|
|
// Non-message scalars replace the previous value.
|
|
m.Set(fd, v)
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
|
|
if wtyp != wire.BytesType {
|
|
return 0, errUnknown
|
|
}
|
|
b, n = wire.ConsumeBytes(b)
|
|
if n < 0 {
|
|
return 0, wire.ParseError(n)
|
|
}
|
|
var (
|
|
keyField = fd.MapKey()
|
|
valField = fd.MapValue()
|
|
key protoreflect.Value
|
|
val protoreflect.Value
|
|
haveKey bool
|
|
haveVal bool
|
|
)
|
|
switch valField.Kind() {
|
|
case protoreflect.GroupKind, protoreflect.MessageKind:
|
|
val = mapv.NewValue()
|
|
}
|
|
// Map entries are represented as a two-element message with fields
|
|
// containing the key and value.
|
|
for len(b) > 0 {
|
|
num, wtyp, n := wire.ConsumeTag(b)
|
|
if n < 0 {
|
|
return 0, wire.ParseError(n)
|
|
}
|
|
if num > wire.MaxValidNumber {
|
|
return 0, errors.New("invalid field number")
|
|
}
|
|
b = b[n:]
|
|
err = errUnknown
|
|
switch num {
|
|
case 1:
|
|
key, n, err = o.unmarshalScalar(b, wtyp, keyField)
|
|
if err != nil {
|
|
break
|
|
}
|
|
haveKey = true
|
|
case 2:
|
|
var v protoreflect.Value
|
|
v, n, err = o.unmarshalScalar(b, wtyp, valField)
|
|
if err != nil {
|
|
break
|
|
}
|
|
switch valField.Kind() {
|
|
case protoreflect.GroupKind, protoreflect.MessageKind:
|
|
if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
|
|
return 0, err
|
|
}
|
|
default:
|
|
val = v
|
|
}
|
|
haveVal = true
|
|
}
|
|
if err == errUnknown {
|
|
n = wire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return 0, wire.ParseError(n)
|
|
}
|
|
} else if err != nil {
|
|
return 0, err
|
|
}
|
|
b = b[n:]
|
|
}
|
|
// Every map entry should have entries for key and value, but this is not strictly required.
|
|
if !haveKey {
|
|
key = keyField.Default()
|
|
}
|
|
if !haveVal {
|
|
switch valField.Kind() {
|
|
case protoreflect.GroupKind, protoreflect.MessageKind:
|
|
default:
|
|
val = valField.Default()
|
|
}
|
|
}
|
|
mapv.Set(key.MapKey(), val)
|
|
return n, nil
|
|
}
|
|
|
|
// errUnknown is used internally to indicate fields which should be added
// to the unknown field set of a message. It is never returned from an exported
// function.
//
// Field decoders (such as unmarshalMap on a wire-type mismatch) return it.
// unmarshalMessageSlow then skips the field's value with
// wire.ConsumeFieldValue and appends the raw bytes via m.SetUnknown.
var errUnknown = errors.New("BUG: internal error (unknown)")
|