666 lines
17 KiB
Go
Raw Normal View History

// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protojson
import (
"encoding/base64"
"fmt"
"math"
"strconv"
"strings"
"google.golang.org/protobuf/internal/encoding/json"
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/genid"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/internal/set"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
// Unmarshal reads the given []byte into the given [proto.Message].
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func Unmarshal(b []byte, m proto.Message) error {
return UnmarshalOptions{}.Unmarshal(b, m)
}
// UnmarshalOptions is a configurable JSON format parser.
type UnmarshalOptions struct {
pragma.NoUnkeyedLiterals
// If AllowPartial is set, input for messages that will result in missing
// required fields will not return an error.
AllowPartial bool
// If DiscardUnknown is set, unknown fields are ignored.
DiscardUnknown bool
proto, encoding/protojson, encoding/prototext: use Resolver interface Instead of accepting a concrete protoregistry.Types type, accept an interface that provides the necessary functionality to perform the serialization. The advantages of this approach: * There is no need for complex logic to allow a Parent or custom Resolver on the protoregistry.Types type. * Users can pass their own custom resolver implementations directly to the serialization functions. * This is a more principled approach to plumbing custom resolvers than the previous approach of overloading behavior on the concrete Types type. The disadvantages of this approach: * A pointer to a concrete type is 8B, while an interface is 16B. However, the expansion of the {Marshal,Unmarshal}Options structs should be a concern solved separately from how to plumb custom resolvers. * The resolver interfaces as defined today may be insufficient to provide functionality needed in the future if protobuf expands its feature set. For example, let's suppose the Any message permits directly representing a enum by name. This would require the ability to lookup an enum by name. To support that hypothetical need, we can document that the serializers type-assert the provided Resolver to a EnumTypeResolver and use that if possible. There is some loss of type safety with this approach, but provides a clear path forward. Change-Id: I81ca80e59335d36be6b43d57ec8e17abfdfa3bad Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/177044 Reviewed-by: Damien Neil <dneil@google.com>
2019-05-14 14:28:19 -07:00
// Resolver is used for looking up types when unmarshaling
// google.protobuf.Any messages or extension fields.
// If nil, this defaults to using protoregistry.GlobalTypes.
Resolver interface {
protoregistry.MessageTypeResolver
protoregistry.ExtensionTypeResolver
}
}
// Unmarshal reads the given []byte and populates the given [proto.Message]
// using options in the UnmarshalOptions object.
// It will clear the message first before setting the fields.
// If it returns an error, the given message may be partially set.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
return o.unmarshal(b, m)
}
// unmarshal is a centralized function that all unmarshal operations go through.
// For profiling purposes, avoid changing the name of this function or
// introducing other code paths for unmarshal that do not go through this.
func (o UnmarshalOptions) unmarshal(b []byte, m proto.Message) error {
proto.Reset(m)
if o.Resolver == nil {
o.Resolver = protoregistry.GlobalTypes
}
dec := decoder{json.NewDecoder(b), o}
if err := dec.unmarshalMessage(m.ProtoReflect(), false); err != nil {
return err
}
// Check for EOF.
tok, err := dec.Read()
if err != nil {
return err
}
if tok.Kind() != json.EOF {
return dec.unexpectedTokenError(tok)
}
if o.AllowPartial {
return nil
}
return proto.CheckInitialized(m)
}
type decoder struct {
*json.Decoder
opts UnmarshalOptions
}
// newError returns an error object with position info.
func (d decoder) newError(pos int, f string, x ...interface{}) error {
line, column := d.Position(pos)
head := fmt.Sprintf("(line %d:%d): ", line, column)
return errors.New(head+f, x...)
}
// unexpectedTokenError returns a syntax error for the given unexpected token.
func (d decoder) unexpectedTokenError(tok json.Token) error {
return d.syntaxError(tok.Pos(), "unexpected token %s", tok.RawString())
}
// syntaxError returns a syntax error for given position.
func (d decoder) syntaxError(pos int, f string, x ...interface{}) error {
line, column := d.Position(pos)
head := fmt.Sprintf("syntax error (line %d:%d): ", line, column)
return errors.New(head+f, x...)
}
// unmarshalMessage unmarshals a message into the given protoreflect.Message.
func (d decoder) unmarshalMessage(m protoreflect.Message, skipTypeURL bool) error {
if unmarshal := wellKnownTypeUnmarshaler(m.Descriptor().FullName()); unmarshal != nil {
return unmarshal(d, m)
}
tok, err := d.Read()
if err != nil {
return err
}
if tok.Kind() != json.ObjectOpen {
return d.unexpectedTokenError(tok)
}
messageDesc := m.Descriptor()
if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) {
return errors.New("no support for proto1 MessageSets")
}
var seenNums set.Ints
var seenOneofs set.Ints
fieldDescs := messageDesc.Fields()
for {
// Read field name.
tok, err := d.Read()
if err != nil {
return err
}
switch tok.Kind() {
default:
return d.unexpectedTokenError(tok)
case json.ObjectClose:
return nil
case json.Name:
// Continue below.
}
name := tok.Name()
// Unmarshaling a non-custom embedded message in Any will contain the
// JSON field "@type" which should be skipped because it is not a field
// of the embedded message, but simply an artifact of the Any format.
if skipTypeURL && name == "@type" {
d.Read()
continue
}
// Get the FieldDescriptor.
var fd protoreflect.FieldDescriptor
if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
// Only extension names are in [name] format.
extName := protoreflect.FullName(name[1 : len(name)-1])
extType, err := d.opts.Resolver.FindExtensionByName(extName)
if err != nil && err != protoregistry.NotFound {
return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
}
if extType != nil {
fd = extType.TypeDescriptor()
if !messageDesc.ExtensionRanges().Has(fd.Number()) || fd.ContainingMessage().FullName() != messageDesc.FullName() {
return d.newError(tok.Pos(), "message %v cannot be extended by %v", messageDesc.FullName(), fd.FullName())
}
}
} else {
// The name can either be the JSON name or the proto field name.
fd = fieldDescs.ByJSONName(name)
if fd == nil {
fd = fieldDescs.ByTextName(name)
}
}
if flags.ProtoLegacy {
if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
fd = nil // reset since the weak reference is not linked in
}
}
if fd == nil {
// Field is unknown.
if d.opts.DiscardUnknown {
if err := d.skipJSONValue(); err != nil {
return err
}
continue
}
return d.newError(tok.Pos(), "unknown field %v", tok.RawString())
}
// Do not allow duplicate fields.
num := uint64(fd.Number())
if seenNums.Has(num) {
return d.newError(tok.Pos(), "duplicate field %v", tok.RawString())
}
seenNums.Set(num)
// No need to set values for JSON null unless the field type is
// google.protobuf.Value or google.protobuf.NullValue.
if tok, _ := d.Peek(); tok.Kind() == json.Null && !isKnownValue(fd) && !isNullValue(fd) {
d.Read()
continue
}
switch {
case fd.IsList():
list := m.Mutable(fd).List()
if err := d.unmarshalList(list, fd); err != nil {
return err
}
case fd.IsMap():
mmap := m.Mutable(fd).Map()
if err := d.unmarshalMap(mmap, fd); err != nil {
return err
}
default:
// If field is a oneof, check if it has already been set.
if od := fd.ContainingOneof(); od != nil {
idx := uint64(od.Index())
if seenOneofs.Has(idx) {
return d.newError(tok.Pos(), "error parsing %s, oneof %v is already set", tok.RawString(), od.FullName())
}
seenOneofs.Set(idx)
}
// Required or optional fields.
if err := d.unmarshalSingular(m, fd); err != nil {
return err
}
}
}
}
func isKnownValue(fd protoreflect.FieldDescriptor) bool {
md := fd.Message()
return md != nil && md.FullName() == genid.Value_message_fullname
}
func isNullValue(fd protoreflect.FieldDescriptor) bool {
ed := fd.Enum()
return ed != nil && ed.FullName() == genid.NullValue_enum_fullname
}
// unmarshalSingular unmarshals to the non-repeated field specified
// by the given FieldDescriptor.
func (d decoder) unmarshalSingular(m protoreflect.Message, fd protoreflect.FieldDescriptor) error {
var val protoreflect.Value
var err error
switch fd.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
val = m.NewField(fd)
err = d.unmarshalMessage(val.Message(), false)
default:
val, err = d.unmarshalScalar(fd)
}
if err != nil {
return err
}
m.Set(fd, val)
return nil
}
// unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by
// the given FieldDescriptor.
func (d decoder) unmarshalScalar(fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
const b32 int = 32
const b64 int = 64
tok, err := d.Read()
if err != nil {
return protoreflect.Value{}, err
}
kind := fd.Kind()
switch kind {
case protoreflect.BoolKind:
if tok.Kind() == json.Bool {
return protoreflect.ValueOfBool(tok.Bool()), nil
}
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
if v, ok := unmarshalInt(tok, b32); ok {
return v, nil
}
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
if v, ok := unmarshalInt(tok, b64); ok {
return v, nil
}
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
if v, ok := unmarshalUint(tok, b32); ok {
return v, nil
}
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
if v, ok := unmarshalUint(tok, b64); ok {
return v, nil
}
case protoreflect.FloatKind:
if v, ok := unmarshalFloat(tok, b32); ok {
return v, nil
}
case protoreflect.DoubleKind:
if v, ok := unmarshalFloat(tok, b64); ok {
return v, nil
}
case protoreflect.StringKind:
if tok.Kind() == json.String {
return protoreflect.ValueOfString(tok.ParsedString()), nil
}
case protoreflect.BytesKind:
if v, ok := unmarshalBytes(tok); ok {
return v, nil
}
case protoreflect.EnumKind:
if v, ok := unmarshalEnum(tok, fd); ok {
return v, nil
}
default:
panic(fmt.Sprintf("unmarshalScalar: invalid scalar kind %v", kind))
}
return protoreflect.Value{}, d.newError(tok.Pos(), "invalid value for %v type: %v", kind, tok.RawString())
}
func unmarshalInt(tok json.Token, bitSize int) (protoreflect.Value, bool) {
switch tok.Kind() {
case json.Number:
return getInt(tok, bitSize)
case json.String:
// Decode number from string.
s := strings.TrimSpace(tok.ParsedString())
if len(s) != len(tok.ParsedString()) {
return protoreflect.Value{}, false
}
dec := json.NewDecoder([]byte(s))
tok, err := dec.Read()
if err != nil {
return protoreflect.Value{}, false
}
return getInt(tok, bitSize)
}
return protoreflect.Value{}, false
}
func getInt(tok json.Token, bitSize int) (protoreflect.Value, bool) {
n, ok := tok.Int(bitSize)
if !ok {
return protoreflect.Value{}, false
}
if bitSize == 32 {
return protoreflect.ValueOfInt32(int32(n)), true
}
return protoreflect.ValueOfInt64(n), true
}
func unmarshalUint(tok json.Token, bitSize int) (protoreflect.Value, bool) {
switch tok.Kind() {
case json.Number:
return getUint(tok, bitSize)
case json.String:
// Decode number from string.
s := strings.TrimSpace(tok.ParsedString())
if len(s) != len(tok.ParsedString()) {
return protoreflect.Value{}, false
}
dec := json.NewDecoder([]byte(s))
tok, err := dec.Read()
if err != nil {
return protoreflect.Value{}, false
}
return getUint(tok, bitSize)
}
return protoreflect.Value{}, false
}
func getUint(tok json.Token, bitSize int) (protoreflect.Value, bool) {
n, ok := tok.Uint(bitSize)
if !ok {
return protoreflect.Value{}, false
}
if bitSize == 32 {
return protoreflect.ValueOfUint32(uint32(n)), true
}
return protoreflect.ValueOfUint64(n), true
}
func unmarshalFloat(tok json.Token, bitSize int) (protoreflect.Value, bool) {
switch tok.Kind() {
case json.Number:
return getFloat(tok, bitSize)
case json.String:
s := tok.ParsedString()
switch s {
case "NaN":
if bitSize == 32 {
return protoreflect.ValueOfFloat32(float32(math.NaN())), true
}
return protoreflect.ValueOfFloat64(math.NaN()), true
case "Infinity":
if bitSize == 32 {
return protoreflect.ValueOfFloat32(float32(math.Inf(+1))), true
}
return protoreflect.ValueOfFloat64(math.Inf(+1)), true
case "-Infinity":
if bitSize == 32 {
return protoreflect.ValueOfFloat32(float32(math.Inf(-1))), true
}
return protoreflect.ValueOfFloat64(math.Inf(-1)), true
}
// Decode number from string.
if len(s) != len(strings.TrimSpace(s)) {
return protoreflect.Value{}, false
}
dec := json.NewDecoder([]byte(s))
tok, err := dec.Read()
if err != nil {
return protoreflect.Value{}, false
}
return getFloat(tok, bitSize)
}
return protoreflect.Value{}, false
}
func getFloat(tok json.Token, bitSize int) (protoreflect.Value, bool) {
n, ok := tok.Float(bitSize)
if !ok {
return protoreflect.Value{}, false
}
if bitSize == 32 {
return protoreflect.ValueOfFloat32(float32(n)), true
}
return protoreflect.ValueOfFloat64(n), true
}
func unmarshalBytes(tok json.Token) (protoreflect.Value, bool) {
if tok.Kind() != json.String {
return protoreflect.Value{}, false
}
s := tok.ParsedString()
enc := base64.StdEncoding
if strings.ContainsAny(s, "-_") {
enc = base64.URLEncoding
}
if len(s)%4 != 0 {
enc = enc.WithPadding(base64.NoPadding)
}
b, err := enc.DecodeString(s)
if err != nil {
return protoreflect.Value{}, false
}
return protoreflect.ValueOfBytes(b), true
}
func unmarshalEnum(tok json.Token, fd protoreflect.FieldDescriptor) (protoreflect.Value, bool) {
switch tok.Kind() {
case json.String:
// Lookup EnumNumber based on name.
s := tok.ParsedString()
if enumVal := fd.Enum().Values().ByName(protoreflect.Name(s)); enumVal != nil {
return protoreflect.ValueOfEnum(enumVal.Number()), true
}
case json.Number:
if n, ok := tok.Int(32); ok {
return protoreflect.ValueOfEnum(protoreflect.EnumNumber(n)), true
}
case json.Null:
// This is only valid for google.protobuf.NullValue.
if isNullValue(fd) {
return protoreflect.ValueOfEnum(0), true
}
}
return protoreflect.Value{}, false
}
func (d decoder) unmarshalList(list protoreflect.List, fd protoreflect.FieldDescriptor) error {
tok, err := d.Read()
if err != nil {
return err
}
if tok.Kind() != json.ArrayOpen {
return d.unexpectedTokenError(tok)
}
switch fd.Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
for {
tok, err := d.Peek()
if err != nil {
return err
}
if tok.Kind() == json.ArrayClose {
d.Read()
return nil
}
val := list.NewElement()
if err := d.unmarshalMessage(val.Message(), false); err != nil {
return err
}
list.Append(val)
}
default:
for {
tok, err := d.Peek()
if err != nil {
return err
}
if tok.Kind() == json.ArrayClose {
d.Read()
return nil
}
val, err := d.unmarshalScalar(fd)
if err != nil {
return err
}
list.Append(val)
}
}
return nil
}
func (d decoder) unmarshalMap(mmap protoreflect.Map, fd protoreflect.FieldDescriptor) error {
tok, err := d.Read()
if err != nil {
return err
}
if tok.Kind() != json.ObjectOpen {
return d.unexpectedTokenError(tok)
}
// Determine ahead whether map entry is a scalar type or a message type in
// order to call the appropriate unmarshalMapValue func inside the for loop
// below.
var unmarshalMapValue func() (protoreflect.Value, error)
switch fd.MapValue().Kind() {
case protoreflect.MessageKind, protoreflect.GroupKind:
unmarshalMapValue = func() (protoreflect.Value, error) {
val := mmap.NewValue()
if err := d.unmarshalMessage(val.Message(), false); err != nil {
return protoreflect.Value{}, err
}
return val, nil
}
default:
unmarshalMapValue = func() (protoreflect.Value, error) {
return d.unmarshalScalar(fd.MapValue())
}
}
Loop:
for {
// Read field name.
tok, err := d.Read()
if err != nil {
return err
}
switch tok.Kind() {
default:
return d.unexpectedTokenError(tok)
case json.ObjectClose:
break Loop
case json.Name:
// Continue.
}
// Unmarshal field name.
pkey, err := d.unmarshalMapKey(tok, fd.MapKey())
if err != nil {
return err
}
// Check for duplicate field name.
if mmap.Has(pkey) {
return d.newError(tok.Pos(), "duplicate map key %v", tok.RawString())
}
// Read and unmarshal field value.
pval, err := unmarshalMapValue()
if err != nil {
return err
}
mmap.Set(pkey, pval)
}
return nil
}
// unmarshalMapKey converts given token of Name kind into a protoreflect.MapKey.
// A map key type is any integral or string type.
func (d decoder) unmarshalMapKey(tok json.Token, fd protoreflect.FieldDescriptor) (protoreflect.MapKey, error) {
const b32 = 32
const b64 = 64
const base10 = 10
name := tok.Name()
kind := fd.Kind()
switch kind {
case protoreflect.StringKind:
return protoreflect.ValueOfString(name).MapKey(), nil
case protoreflect.BoolKind:
switch name {
case "true":
return protoreflect.ValueOfBool(true).MapKey(), nil
case "false":
return protoreflect.ValueOfBool(false).MapKey(), nil
}
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
if n, err := strconv.ParseInt(name, base10, b32); err == nil {
return protoreflect.ValueOfInt32(int32(n)).MapKey(), nil
}
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
if n, err := strconv.ParseInt(name, base10, b64); err == nil {
return protoreflect.ValueOfInt64(int64(n)).MapKey(), nil
}
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
if n, err := strconv.ParseUint(name, base10, b32); err == nil {
return protoreflect.ValueOfUint32(uint32(n)).MapKey(), nil
}
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
if n, err := strconv.ParseUint(name, base10, b64); err == nil {
return protoreflect.ValueOfUint64(uint64(n)).MapKey(), nil
}
default:
panic(fmt.Sprintf("invalid kind for map key: %v", kind))
}
return protoreflect.MapKey{}, d.newError(tok.Pos(), "invalid value for %v key: %s", kind, tok.RawString())
}