protobuf-go/proto/decode.go
Joe Tsai 3bc7d6f5cd reflect: switch MessageType.New to return Message
Most usages of New actually prefer to interact with the reflective view
rather than the native Go type. Thus, change New to return that instead.
This parallels reflect.New, which returns the reflective view
(i.e., reflect.Value) instead of native type (i.e., interface{}).
We make the equivalent change to KnownFields.NewMessage, List.NewMessage,
and Map.NewMessage for consistency.

Since this is a subtle change where the type system will not always
catch the changed type, this change was made by both changing the type
and renaming the function to NewXXX and manually looking at every usage
of the function to ensure that the usage correctly operates
on either the native Go type or the reflective view of the type.
After the entire codebase was cleaned up, a rename was performed to convert
NewXXX back to New.

Change-Id: I153fef627b4bf0a427e4039ce0aaec52e20c7950
Reviewed-on: https://go-review.googlesource.com/c/157077
Reviewed-by: Damien Neil <dneil@google.com>
2019-01-09 20:29:29 +00:00

195 lines
5.5 KiB
Go

// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proto
import (
"errors"
"github.com/golang/protobuf/v2/internal/encoding/wire"
"github.com/golang/protobuf/v2/internal/pragma"
"github.com/golang/protobuf/v2/reflect/protoreflect"
)
// UnmarshalOptions holds the settings that control how wire-format data is
// parsed. The zero value is usable and selects the default behavior.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	// DiscardUnknown, when set, causes unknown fields to be ignored
	// rather than kept with the decoded message.
	DiscardUnknown bool

	pragma.NoUnkeyedLiterals
}
// Unmarshal decodes the wire-format message in b into m using the default
// options. It is shorthand for UnmarshalOptions{}.Unmarshal(b, m).
func Unmarshal(b []byte, m Message) error {
	var defaults UnmarshalOptions
	return defaults.Unmarshal(b, m)
}
// Unmarshal decodes the wire-format message in b into m according to the
// settings in o. The existing contents of m are not cleared before decoding.
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
	// TODO: Reset m?
	reflectView := m.ProtoReflect()
	return o.unmarshalMessage(b, reflectView)
}
// unmarshalMessage parses the wire-format bytes in b and stores the decoded
// fields into m.
//
// Each field number is resolved first against m's own field descriptors and
// then against the extension types registered with m's known fields.
//
// A field is treated as unknown when:
//   - no descriptor matches, or
//   - its value decoder reports errUnknown (for example, a map field not
//     encoded as wire.BytesType).
//
// Unknown fields are skipped over. Unless o.DiscardUnknown is set, their raw
// bytes (tag included) are appended to m's unknown fields. Any other decoding
// error is returned immediately; errUnknown itself is never returned.
func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
	fieldTypes := m.Type().Fields()
	knownFields := m.KnownFields()
	unknownFields := m.UnknownFields()
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		num, wtyp, tagLen := wire.ConsumeTag(b)
		if tagLen < 0 {
			return wire.ParseError(tagLen)
		}

		// Resolve the field: declared fields first, then registered extensions.
		fieldType := fieldTypes.ByNumber(num)
		if fieldType == nil {
			fieldType = knownFields.ExtensionTypes().ByNumber(num)
		}

		// Parse the field value.
		var err error
		var valLen int
		switch {
		case fieldType == nil:
			err = errUnknown
		case fieldType.Cardinality() != protoreflect.Repeated:
			valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
		case !fieldType.IsMap():
			valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
		default:
			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
		}

		if err == errUnknown {
			// The value must still be consumed so parsing can continue,
			// whether or not the bytes are kept.
			valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
			if valLen < 0 {
				return wire.ParseError(valLen)
			}
			if !o.DiscardUnknown {
				unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
			}
		} else if err != nil {
			return err
		}
		b = b[tagLen+valLen:]
	}
	// TODO: required field checks
	return nil
}
// unmarshalScalarField decodes one value of the singular (non-repeated) field
// described by field from b and stores it in knownFields under num.
//
// Non-message values overwrite whatever was stored before. Message and group
// values are merged into the existing message when one is present. Oneof
// members are the exception: they always receive a freshly created message,
// which is stored before its contents are parsed.
//
// It returns the number of bytes of b consumed. On failure it returns the
// error from the value decoder or from the nested message parse.
func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
	kind := field.Kind()
	v, n, err := o.unmarshalScalar(b, wtyp, num, kind)
	if err != nil {
		return 0, err
	}

	isMessage := kind == protoreflect.GroupKind || kind == protoreflect.MessageKind
	if !isMessage {
		// Non-message scalars replace the previous value.
		knownFields.Set(num, v)
		return n, nil
	}

	// TODO: C++ merges into oneofs, while v1 does not.
	// Evaluate which behavior to pick.
	mergeIntoExisting := knownFields.Has(num) && field.OneofType() == nil
	var sub protoreflect.Message
	if mergeIntoExisting {
		sub = knownFields.Get(num).Message()
	} else {
		sub = knownFields.NewMessage(num)
		knownFields.Set(num, protoreflect.ValueOf(sub))
	}
	if err := o.unmarshalMessage(v.Bytes(), sub); err != nil {
		return 0, err
	}
	return n, nil
}
// unmarshalMap decodes a single map entry from b and stores it in mapv.
//
// A map entry is a length-delimited, two-field message: field 1 holds the key
// and field 2 holds the value. Other field numbers are skipped, as are
// key/value fields whose decoder reports errUnknown (e.g. wrong wire type).
// A missing key is replaced by the key field's default. A missing
// non-message value is replaced by the value field's default. A missing
// message value leaves the empty message allocated up front.
//
// It returns the full length of the entry (length prefix included). If wtyp
// is not wire.BytesType, it returns errUnknown so the caller can treat the
// field as unknown.
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
	if wtyp != wire.BytesType {
		return 0, errUnknown
	}
	// n is the total size of the entry and is what this function reports as
	// consumed; b is narrowed to the entry's contents.
	b, n = wire.ConsumeBytes(b)
	if n < 0 {
		return 0, wire.ParseError(n)
	}
	var (
		keyField = field.MessageType().Fields().ByNumber(1)
		valField = field.MessageType().Fields().ByNumber(2)
		key      protoreflect.Value
		val      protoreflect.Value
		haveKey  bool
		haveVal  bool
	)
	valIsMessage := false
	switch valField.Kind() {
	case protoreflect.GroupKind, protoreflect.MessageKind:
		valIsMessage = true
		// All value fields within this entry are decoded into this one
		// message, so repeated occurrences are merged together.
		val = protoreflect.ValueOf(mapv.NewMessage())
	}

	for len(b) > 0 {
		entryNum, entryWtyp, tagLen := wire.ConsumeTag(b)
		if tagLen < 0 {
			return 0, wire.ParseError(tagLen)
		}
		b = b[tagLen:]

		var valLen int
		ferr := errUnknown
		switch entryNum {
		case 1:
			// Decode into a temporary. A later key field with the wrong wire
			// type must not clobber a key already decoded in this entry.
			// Otherwise key.MapKey() below would run on an invalid value.
			var k protoreflect.Value
			k, valLen, ferr = o.unmarshalScalar(b, entryWtyp, entryNum, keyField.Kind())
			if ferr == nil {
				key = k
				haveKey = true
			}
		case 2:
			var v protoreflect.Value
			v, valLen, ferr = o.unmarshalScalar(b, entryWtyp, entryNum, valField.Kind())
			if ferr == nil {
				if valIsMessage {
					if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
						return 0, err
					}
				} else {
					val = v
				}
				haveVal = true
			}
		}

		if ferr == errUnknown {
			valLen = wire.ConsumeFieldValue(entryNum, entryWtyp, b)
			if valLen < 0 {
				return 0, wire.ParseError(valLen)
			}
		} else if ferr != nil {
			return 0, ferr
		}
		b = b[valLen:]
	}

	// Every map entry should have entries for key and value, but this is not strictly required.
	if !haveKey {
		key = keyField.Default()
	}
	if !haveVal {
		if valIsMessage {
			// Trigger required field checks by unmarshaling an empty message.
			if err := o.unmarshalMessage(nil, val.Message()); err != nil {
				return 0, err
			}
		} else {
			val = valField.Default()
		}
	}
	mapv.Set(key.MapKey(), val)
	return n, nil
}
// errUnknown is used internally to indicate fields which should be added
// to the unknown field set of a message. It is never returned from an exported
// function.
var errUnknown = errors.New("unknown")