diff --git a/encoding/jsonpb/decode.go b/encoding/jsonpb/decode.go new file mode 100644 index 00000000..a43efd12 --- /dev/null +++ b/encoding/jsonpb/decode.go @@ -0,0 +1,634 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jsonpb + +import ( + "encoding/base64" + "fmt" + "math" + "strconv" + "strings" + + "github.com/golang/protobuf/v2/internal/encoding/json" + "github.com/golang/protobuf/v2/internal/errors" + "github.com/golang/protobuf/v2/internal/set" + "github.com/golang/protobuf/v2/proto" + pref "github.com/golang/protobuf/v2/reflect/protoreflect" +) + +// Unmarshal reads the given []byte into the given proto.Message. +func Unmarshal(m proto.Message, b []byte) error { + return UnmarshalOptions{}.Unmarshal(m, b) +} + +// UnmarshalOptions is a configurable JSON format parser. +type UnmarshalOptions struct{} + +// Unmarshal reads the given []byte and populates the given proto.Message using +// options in UnmarshalOptions object. It will clear the message first before +// setting the fields. If it returns an error, the given message may be +// partially set. +func (o UnmarshalOptions) Unmarshal(m proto.Message, b []byte) error { + mr := m.ProtoReflect() + // TODO: Determine if we would like to have an option for merging or only + // have merging behavior. We should at least be consistent with textproto + // marshaling. + resetMessage(mr) + + dec := decoder{json.NewDecoder(b)} + var nerr errors.NonFatal + if err := dec.unmarshalMessage(mr); !nerr.Merge(err) { + return err + } + + // Check for EOF. + val, err := dec.Read() + if err != nil { + return err + } + if val.Type() != json.EOF { + return unexpectedJSONError{val} + } + return nerr.E +} + +// resetMessage clears all fields of given protoreflect.Message. +func resetMessage(m pref.Message) { + knownFields := m.KnownFields() + knownFields.Range(func(num pref.FieldNumber, _ pref.Value) bool { + knownFields.Clear(num) + return true + }) + unknownFields := m.UnknownFields() + unknownFields.Range(func(num pref.FieldNumber, _ pref.RawFields) bool { + unknownFields.Set(num, nil) + return true + }) + extTypes := knownFields.ExtensionTypes() + extTypes.Range(func(xt pref.ExtensionType) bool { + extTypes.Remove(xt) + return true + }) +} + +// unexpectedJSONError is an error that contains the unexpected json.Value. This +// is used by decoder methods to provide callers the read json.Value that it +// did not expect. +// TODO: Consider moving this to internal/encoding/json for consistency with +// errors that package returns. +type unexpectedJSONError struct { + value json.Value +} + +func (e unexpectedJSONError) Error() string { + return newError("unexpected value %s", e.value).Error() +} + +// newError returns an error object. If one of the values passed in is of +// json.Value type, it produces an error with position info. +func newError(f string, x ...interface{}) error { + var hasValue bool + var line, column int + for i := 0; i < len(x); i++ { + if val, ok := x[i].(json.Value); ok { + line, column = val.Position() + hasValue = true + break + } + } + e := errors.New(f, x...) + if hasValue { + return errors.New("(line %d:%d): %v", line, column, e) + } + return e +} + +// decoder decodes JSON into protoreflect values. +type decoder struct { + *json.Decoder +} + +// unmarshalMessage unmarshals a message into the given protoreflect.Message. +func (d decoder) unmarshalMessage(m pref.Message) error { + var nerr errors.NonFatal + var reqNums set.Ints + var seenNums set.Ints + + msgType := m.Type() + knownFields := m.KnownFields() + fieldDescs := msgType.Fields() + + jval, err := d.Read() + if !nerr.Merge(err) { + return err + } + if jval.Type() != json.StartObject { + return unexpectedJSONError{jval} + } + +Loop: + for { + // Read field name. + jval, err := d.Read() + if !nerr.Merge(err) { + return err + } + switch jval.Type() { + default: + return unexpectedJSONError{jval} + case json.EndObject: + break Loop + case json.Name: + // Continue below. + } + + name, err := jval.Name() + if !nerr.Merge(err) { + return err + } + + // Get the FieldDescriptor based on the field name. The name can either + // be the JSON name for the field or the proto field name. + fd := fieldDescs.ByJSONName(name) + if fd == nil { + fd = fieldDescs.ByName(pref.Name(name)) + } + + if fd == nil { + // Field is unknown. + // TODO: Provide option to ignore unknown message fields. + return newError("%v contains unknown field %s", msgType.FullName(), jval) + } + + // Do not allow duplicate fields. + num := uint64(fd.Number()) + if seenNums.Has(num) { + return newError("%v contains repeated field %s", msgType.FullName(), jval) + } + seenNums.Set(num) + + // No need to set values for JSON null. + if d.Peek() == json.Null { + d.Read() + continue + } + + if cardinality := fd.Cardinality(); cardinality == pref.Repeated { + // Map or list fields have cardinality of repeated. + if err := d.unmarshalRepeated(fd, knownFields); !nerr.Merge(err) { + return errors.New("%v|%q: %v", fd.FullName(), name, err) + } + } else { + // Required or optional fields. + if err := d.unmarshalSingular(fd, knownFields); !nerr.Merge(err) { + return errors.New("%v|%q: %v", fd.FullName(), name, err) + } + if cardinality == pref.Required { + reqNums.Set(num) + } + } + } + + // Check for any missing required fields. + allReqNums := msgType.RequiredNumbers() + if reqNums.Len() != allReqNums.Len() { + for i := 0; i < allReqNums.Len(); i++ { + if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) { + nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName())) + } + } + } + + return nerr.E +} + +// unmarshalSingular unmarshals to the non-repeated field specified by the given +// FieldDescriptor. +func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, knownFields pref.KnownFields) error { + var val pref.Value + var err error + num := fd.Number() + + switch fd.Kind() { + case pref.MessageKind, pref.GroupKind: + m := knownFields.NewMessage(num) + err = d.unmarshalMessage(m) + val = pref.ValueOf(m) + default: + val, err = d.unmarshalScalar(fd) + } + + var nerr errors.NonFatal + if !nerr.Merge(err) { + return err + } + knownFields.Set(num, val) + return nerr.E +} + +// unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by +// the given FieldDescriptor. +func (d decoder) unmarshalScalar(fd pref.FieldDescriptor) (pref.Value, error) { + const b32 int = 32 + const b64 int = 64 + + var nerr errors.NonFatal + jval, err := d.Read() + if !nerr.Merge(err) { + return pref.Value{}, err + } + + kind := fd.Kind() + switch kind { + case pref.BoolKind: + return unmarshalBool(jval) + + case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind: + return unmarshalInt(jval, b32) + + case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind: + return unmarshalInt(jval, b64) + + case pref.Uint32Kind, pref.Fixed32Kind: + return unmarshalUint(jval, b32) + + case pref.Uint64Kind, pref.Fixed64Kind: + return unmarshalUint(jval, b64) + + case pref.FloatKind: + return unmarshalFloat(jval, b32) + + case pref.DoubleKind: + return unmarshalFloat(jval, b64) + + case pref.StringKind: + pval, err := unmarshalString(jval) + if !nerr.Merge(err) { + return pval, err + } + return pval, nerr.E + + case pref.BytesKind: + return unmarshalBytes(jval) + + case pref.EnumKind: + return unmarshalEnum(jval, fd) + } + + panic(fmt.Sprintf("invalid scalar kind %v", kind)) +} + +func unmarshalBool(jval json.Value) (pref.Value, error) { + if jval.Type() != json.Bool { + return pref.Value{}, unexpectedJSONError{jval} + } + b, err := jval.Bool() + return pref.ValueOf(b), err +} + +func unmarshalInt(jval json.Value, bitSize int) (pref.Value, error) { + switch jval.Type() { + case json.Number: + return getInt(jval, bitSize) + + case json.String: + // Use another decoder to decode number from string. + dec := decoder{json.NewDecoder([]byte(jval.String()))} + var nerr errors.NonFatal + jval, err := dec.Read() + if !nerr.Merge(err) { + return pref.Value{}, err + } + return getInt(jval, bitSize) + } + return pref.Value{}, unexpectedJSONError{jval} +} + +func getInt(jval json.Value, bitSize int) (pref.Value, error) { + n, err := jval.Int(bitSize) + if err != nil { + return pref.Value{}, err + } + if bitSize == 32 { + return pref.ValueOf(int32(n)), nil + } + return pref.ValueOf(n), nil +} + +func unmarshalUint(jval json.Value, bitSize int) (pref.Value, error) { + switch jval.Type() { + case json.Number: + return getUint(jval, bitSize) + + case json.String: + // Use another decoder to decode number from string. + dec := decoder{json.NewDecoder([]byte(jval.String()))} + var nerr errors.NonFatal + jval, err := dec.Read() + if !nerr.Merge(err) { + return pref.Value{}, err + } + return getUint(jval, bitSize) + } + return pref.Value{}, unexpectedJSONError{jval} +} + +func getUint(jval json.Value, bitSize int) (pref.Value, error) { + n, err := jval.Uint(bitSize) + if err != nil { + return pref.Value{}, err + } + if bitSize == 32 { + return pref.ValueOf(uint32(n)), nil + } + return pref.ValueOf(n), nil +} + +func unmarshalFloat(jval json.Value, bitSize int) (pref.Value, error) { + switch jval.Type() { + case json.Number: + return getFloat(jval, bitSize) + + case json.String: + s := jval.String() + switch s { + case "NaN": + if bitSize == 32 { + return pref.ValueOf(float32(math.NaN())), nil + } + return pref.ValueOf(math.NaN()), nil + case "Infinity": + if bitSize == 32 { + return pref.ValueOf(float32(math.Inf(+1))), nil + } + return pref.ValueOf(math.Inf(+1)), nil + case "-Infinity": + if bitSize == 32 { + return pref.ValueOf(float32(math.Inf(-1))), nil + } + return pref.ValueOf(math.Inf(-1)), nil + } + // Use another decoder to decode number from string. + dec := decoder{json.NewDecoder([]byte(s))} + var nerr errors.NonFatal + jval, err := dec.Read() + if !nerr.Merge(err) { + return pref.Value{}, err + } + return getFloat(jval, bitSize) + } + return pref.Value{}, unexpectedJSONError{jval} +} + +func getFloat(jval json.Value, bitSize int) (pref.Value, error) { + n, err := jval.Float(bitSize) + if err != nil { + return pref.Value{}, err + } + if bitSize == 32 { + return pref.ValueOf(float32(n)), nil + } + return pref.ValueOf(n), nil +} + +func unmarshalString(jval json.Value) (pref.Value, error) { + if jval.Type() != json.String { + return pref.Value{}, unexpectedJSONError{jval} + } + return pref.ValueOf(jval.String()), nil +} + +func unmarshalBytes(jval json.Value) (pref.Value, error) { + if jval.Type() != json.String { + return pref.Value{}, unexpectedJSONError{jval} + } + + s := jval.String() + enc := base64.StdEncoding + if strings.ContainsAny(s, "-_") { + enc = base64.URLEncoding + } + if len(s)%4 != 0 { + enc = enc.WithPadding(base64.NoPadding) + } + b, err := enc.DecodeString(s) + if err != nil { + return pref.Value{}, err + } + return pref.ValueOf(b), nil +} + +func unmarshalEnum(jval json.Value, fd pref.FieldDescriptor) (pref.Value, error) { + switch jval.Type() { + case json.String: + // Lookup EnumNumber based on name. + s := jval.String() + if enumVal := fd.EnumType().Values().ByName(pref.Name(s)); enumVal != nil { + return pref.ValueOf(enumVal.Number()), nil + } + return pref.Value{}, newError("invalid enum value %q", jval) + + case json.Number: + n, err := jval.Int(32) + if err != nil { + return pref.Value{}, err + } + return pref.ValueOf(pref.EnumNumber(n)), nil + } + + return pref.Value{}, unexpectedJSONError{jval} +} + +// unmarshalRepeated unmarshals into a repeated field. +func (d decoder) unmarshalRepeated(fd pref.FieldDescriptor, knownFields pref.KnownFields) error { + var nerr errors.NonFatal + num := fd.Number() + val := knownFields.Get(num) + if !fd.IsMap() { + if err := d.unmarshalList(fd, val.List()); !nerr.Merge(err) { + return err + } + } else { + if err := d.unmarshalMap(fd, val.Map()); !nerr.Merge(err) { + return err + } + } + return nerr.E +} + +// unmarshalList unmarshals into given protoreflect.List. +func (d decoder) unmarshalList(fd pref.FieldDescriptor, list pref.List) error { + var nerr errors.NonFatal + jval, err := d.Read() + if !nerr.Merge(err) { + return err + } + if jval.Type() != json.StartArray { + return unexpectedJSONError{jval} + } + + switch fd.Kind() { + case pref.MessageKind, pref.GroupKind: + for { + m := list.NewMessage() + err := d.unmarshalMessage(m) + if !nerr.Merge(err) { + if e, ok := err.(unexpectedJSONError); ok { + if e.value.Type() == json.EndArray { + // Done with list. + return nerr.E + } + } + return err + } + list.Append(pref.ValueOf(m)) + } + default: + for { + val, err := d.unmarshalScalar(fd) + if !nerr.Merge(err) { + if e, ok := err.(unexpectedJSONError); ok { + if e.value.Type() == json.EndArray { + // Done with list. + return nerr.E + } + } + return err + } + list.Append(val) + } + } + return nerr.E +} + +// unmarshalMap unmarshals into given protoreflect.Map. +func (d decoder) unmarshalMap(fd pref.FieldDescriptor, mmap pref.Map) error { + var nerr errors.NonFatal + + jval, err := d.Read() + if !nerr.Merge(err) { + return err + } + if jval.Type() != json.StartObject { + return unexpectedJSONError{jval} + } + + fields := fd.MessageType().Fields() + keyDesc := fields.ByNumber(1) + valDesc := fields.ByNumber(2) + + // Determine ahead whether map entry is a scalar type or a message type in + // order to call the appropriate unmarshalMapValue func inside the for loop + // below. + unmarshalMapValue := func() (pref.Value, error) { + return d.unmarshalScalar(valDesc) + } + switch valDesc.Kind() { + case pref.MessageKind, pref.GroupKind: + unmarshalMapValue = func() (pref.Value, error) { + m := mmap.NewMessage() + if err := d.unmarshalMessage(m); err != nil { + return pref.Value{}, err + } + return pref.ValueOf(m), nil + } + } + +Loop: + for { + // Read field name. + jval, err := d.Read() + if !nerr.Merge(err) { + return err + } + switch jval.Type() { + default: + return unexpectedJSONError{jval} + case json.EndObject: + break Loop + case json.Name: + // Continue. + } + + name, err := jval.Name() + if !nerr.Merge(err) { + return err + } + + // Unmarshal field name. + pkey, err := unmarshalMapKey(name, keyDesc) + if !nerr.Merge(err) { + return err + } + + // Check for duplicate field name. + if mmap.Has(pkey) { + return newError("duplicate map key %q", jval) + } + + // Read and unmarshal field value. + pval, err := unmarshalMapValue() + if !nerr.Merge(err) { + return err + } + + mmap.Set(pkey, pval) + } + + return nerr.E +} + +// unmarshalMapKey converts given string into a protoreflect.MapKey. A map key type is any +// integral or string type. +func unmarshalMapKey(name string, fd pref.FieldDescriptor) (pref.MapKey, error) { + const b32 = 32 + const b64 = 64 + const base10 = 10 + + kind := fd.Kind() + switch kind { + case pref.StringKind: + return pref.ValueOf(name).MapKey(), nil + + case pref.BoolKind: + switch name { + case "true": + return pref.ValueOf(true).MapKey(), nil + case "false": + return pref.ValueOf(false).MapKey(), nil + } + return pref.MapKey{}, errors.New("invalid value for boolean key %q", name) + + case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind: + n, err := strconv.ParseInt(name, base10, b32) + if err != nil { + return pref.MapKey{}, err + } + return pref.ValueOf(int32(n)).MapKey(), nil + + case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind: + n, err := strconv.ParseInt(name, base10, b64) + if err != nil { + return pref.MapKey{}, err + } + return pref.ValueOf(int64(n)).MapKey(), nil + + case pref.Uint32Kind, pref.Fixed32Kind: + n, err := strconv.ParseUint(name, base10, b32) + if err != nil { + return pref.MapKey{}, err + } + return pref.ValueOf(uint32(n)).MapKey(), nil + + case pref.Uint64Kind, pref.Fixed64Kind: + n, err := strconv.ParseUint(name, base10, b64) + if err != nil { + return pref.MapKey{}, err + } + return pref.ValueOf(uint64(n)).MapKey(), nil + } + + panic(fmt.Sprintf("%s: invalid kind %s for map key", fd.FullName(), kind)) +} diff --git a/encoding/jsonpb/decode_test.go b/encoding/jsonpb/decode_test.go new file mode 100644 index 00000000..97a0619c --- /dev/null +++ b/encoding/jsonpb/decode_test.go @@ -0,0 +1,927 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package jsonpb_test + +import ( + "math" + "testing" + + protoV1 "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/v2/encoding/jsonpb" + "github.com/golang/protobuf/v2/encoding/testprotos/pb2" + "github.com/golang/protobuf/v2/encoding/testprotos/pb3" + "github.com/golang/protobuf/v2/internal/scalar" + "github.com/golang/protobuf/v2/proto" +) + +func TestUnmarshal(t *testing.T) { + tests := []struct { + desc string + umo jsonpb.UnmarshalOptions + inputMessage proto.Message + inputText string + wantMessage proto.Message + // TODO: verify expected error message substring. + wantErr bool + }{{ + desc: "proto2 empty message", + inputMessage: &pb2.Scalars{}, + inputText: "{}", + wantMessage: &pb2.Scalars{}, + }, { + desc: "unexpected value instead of EOF", + inputMessage: &pb2.Scalars{}, + inputText: "{} {}", + wantErr: true, + }, { + desc: "proto2 optional scalars set to zero values", + inputMessage: &pb2.Scalars{}, + inputText: `{ + "optBool": false, + "optInt32": 0, + "optInt64": 0, + "optUint32": 0, + "optUint64": 0, + "optSint32": 0, + "optSint64": 0, + "optFixed32": 0, + "optFixed64": 0, + "optSfixed32": 0, + "optSfixed64": 0, + "optFloat": 0, + "optDouble": 0, + "optBytes": "", + "optString": "" +}`, + wantMessage: &pb2.Scalars{ + OptBool: scalar.Bool(false), + OptInt32: scalar.Int32(0), + OptInt64: scalar.Int64(0), + OptUint32: scalar.Uint32(0), + OptUint64: scalar.Uint64(0), + OptSint32: scalar.Int32(0), + OptSint64: scalar.Int64(0), + OptFixed32: scalar.Uint32(0), + OptFixed64: scalar.Uint64(0), + OptSfixed32: scalar.Int32(0), + OptSfixed64: scalar.Int64(0), + OptFloat: scalar.Float32(0), + OptDouble: scalar.Float64(0), + OptBytes: []byte{}, + OptString: scalar.String(""), + }, + }, { + desc: "proto3 scalars set to zero values", + inputMessage: &pb3.Scalars{}, + inputText: `{ + "sBool": false, + "sInt32": 0, + "sInt64": 0, + "sUint32": 0, + "sUint64": 0, + "sSint32": 0, + "sSint64": 0, + "sFixed32": 0, + "sFixed64": 0, + "sSfixed32": 0, + "sSfixed64": 0, + "sFloat": 0, + "sDouble": 0, + "sBytes": "", + "sString": "" +}`, + wantMessage: &pb3.Scalars{}, + }, { + desc: "proto2 optional scalars set to null", + inputMessage: &pb2.Scalars{}, + inputText: `{ + "optBool": null, + "optInt32": null, + "optInt64": null, + "optUint32": null, + "optUint64": null, + "optSint32": null, + "optSint64": null, + "optFixed32": null, + "optFixed64": null, + "optSfixed32": null, + "optSfixed64": null, + "optFloat": null, + "optDouble": null, + "optBytes": null, + "optString": null +}`, + wantMessage: &pb2.Scalars{}, + }, { + desc: "proto3 scalars set to null", + inputMessage: &pb3.Scalars{}, + inputText: `{ + "sBool": null, + "sInt32": null, + "sInt64": null, + "sUint32": null, + "sUint64": null, + "sSint32": null, + "sSint64": null, + "sFixed32": null, + "sFixed64": null, + "sSfixed32": null, + "sSfixed64": null, + "sFloat": null, + "sDouble": null, + "sBytes": null, + "sString": null +}`, + wantMessage: &pb3.Scalars{}, + }, { + desc: "boolean", + inputMessage: &pb3.Scalars{}, + inputText: `{"sBool": true}`, + wantMessage: &pb3.Scalars{ + SBool: true, + }, + }, { + desc: "not boolean", + inputMessage: &pb3.Scalars{}, + inputText: `{"sBool": "true"}`, + wantErr: true, + }, { + desc: "float and double", + inputMessage: &pb3.Scalars{}, + inputText: `{ + "sFloat": 1.234, + "sDouble": 5.678 +}`, + wantMessage: &pb3.Scalars{ + SFloat: 1.234, + SDouble: 5.678, + }, + }, { + desc: "float and double in string", + inputMessage: &pb3.Scalars{}, + inputText: `{ + "sFloat": "1.234", + "sDouble": "5.678" +}`, + wantMessage: &pb3.Scalars{ + SFloat: 1.234, + SDouble: 5.678, + }, + }, { + desc: "float and double in E notation", + inputMessage: &pb3.Scalars{}, + inputText: `{ + "sFloat": 12.34E-1, + "sDouble": 5.678e4 +}`, + wantMessage: &pb3.Scalars{ + SFloat: 1.234, + SDouble: 56780, + }, + }, { + desc: "float and double in string E notation", + inputMessage: &pb3.Scalars{}, + inputText: `{ + "sFloat": "12.34E-1", + "sDouble": "5.678e4" +}`, + wantMessage: &pb3.Scalars{ + SFloat: 1.234, + SDouble: 56780, + }, + }, { + desc: "float exceeds limit", + inputMessage: &pb3.Scalars{}, + inputText: `{"sFloat": 3.4e39}`, + wantErr: true, + }, { + desc: "float in string exceeds limit", + inputMessage: &pb3.Scalars{}, + inputText: `{"sFloat": "-3.4e39"}`, + wantErr: true, + }, { + desc: "double exceeds limit", + inputMessage: &pb3.Scalars{}, + inputText: `{"sFloat": -1.79e+309}`, + wantErr: true, + }, { + desc: "double in string exceeds limit", + inputMessage: &pb3.Scalars{}, + inputText: `{"sFloat": "1.79e+309"}`, + wantErr: true, + }, { + desc: "infinites", + inputMessage: &pb3.Scalars{}, + inputText: `{"sFloat": "Infinity", "sDouble": "-Infinity"}`, + wantMessage: &pb3.Scalars{ + SFloat: float32(math.Inf(+1)), + SDouble: math.Inf(-1), + }, + }, { + desc: "not float", + inputMessage: &pb3.Scalars{}, + inputText: `{"sFloat": true}`, + wantErr: true, + }, { + desc: "not double", + inputMessage: &pb3.Scalars{}, + inputText: `{"sDouble": "not a number"}`, + wantErr: true, + }, { + desc: "integers", + inputMessage: &pb3.Scalars{}, + inputText: `{ + "sInt32": 1234, + "sInt64": -1234, + "sUint32": 1e2, + "sUint64": 100E-2, + "sSint32": 1.0, + "sSint64": -1.0, + "sFixed32": 1.234e+5, + "sFixed64": 1200E-2, + "sSfixed32": -1.234e05, + "sSfixed64": -1200e-02 +}`, + wantMessage: &pb3.Scalars{ + SInt32: 1234, + SInt64: -1234, + SUint32: 100, + SUint64: 1, + SSint32: 1, + SSint64: -1, + SFixed32: 123400, + SFixed64: 12, + SSfixed32: -123400, + SSfixed64: -12, + }, + }, { + desc: "integers in string", + inputMessage: &pb3.Scalars{}, + inputText: `{ + "sInt32": "1234", + "sInt64": "-1234", + "sUint32": "1e2", + "sUint64": "100E-2", + "sSint32": "1.0", + "sSint64": "-1.0", + "sFixed32": "1.234e+5", + "sFixed64": "1200E-2", + "sSfixed32": "-1.234e05", + "sSfixed64": "-1200e-02" +}`, + wantMessage: &pb3.Scalars{ + SInt32: 1234, + SInt64: -1234, + SUint32: 100, + SUint64: 1, + SSint32: 1, + SSint64: -1, + SFixed32: 123400, + SFixed64: 12, + SSfixed32: -123400, + SSfixed64: -12, + }, + }, { + desc: "integers in escaped string", + inputMessage: &pb3.Scalars{}, + inputText: `{"sInt32": "\u0031\u0032"}`, + wantMessage: &pb3.Scalars{ + SInt32: 12, + }, + }, { + desc: "number is not an integer", + inputMessage: &pb3.Scalars{}, + inputText: `{"sInt32": 1.001}`, + wantErr: true, + }, { + desc: "32-bit int exceeds limit", + inputMessage: &pb3.Scalars{}, + inputText: `{"sInt32": 2e10}`, + wantErr: true, + }, { + desc: "64-bit int exceeds limit", + inputMessage: &pb3.Scalars{}, + inputText: `{"sSfixed64": -9e19}`, + wantErr: true, + }, { + desc: "not integer", + inputMessage: &pb3.Scalars{}, + inputText: `{"sInt32": "not a number"}`, + wantErr: true, + }, { + desc: "not unsigned integer", + inputMessage: &pb3.Scalars{}, + inputText: `{"sUint32": "not a number"}`, + wantErr: true, + }, { + desc: "number is not an unsigned integer", + inputMessage: &pb3.Scalars{}, + inputText: `{"sUint32": -1}`, + wantErr: true, + }, { + desc: "string", + inputMessage: &pb2.Scalars{}, + inputText: `{"optString": "谷歌"}`, + wantMessage: &pb2.Scalars{ + OptString: scalar.String("谷歌"), + }, + }, { + desc: "string with invalid UTF-8", + inputMessage: &pb3.Scalars{}, + inputText: "{\"sString\": \"\xff\"}", + wantMessage: &pb3.Scalars{ + SString: "\xff", + }, + wantErr: true, + }, { + desc: "not string", + inputMessage: &pb2.Scalars{}, + inputText: `{"optString": 42}`, + wantErr: true, + }, { + desc: "bytes", + inputMessage: &pb3.Scalars{}, + inputText: `{"sBytes": "aGVsbG8gd29ybGQ"}`, + wantMessage: &pb3.Scalars{ + SBytes: []byte("hello world"), + }, + }, { + desc: "bytes padded", + inputMessage: &pb3.Scalars{}, + inputText: `{"sBytes": "aGVsbG8gd29ybGQ="}`, + wantMessage: &pb3.Scalars{ + SBytes: []byte("hello world"), + }, + }, { + desc: "not bytes", + inputMessage: &pb3.Scalars{}, + inputText: `{"sBytes": true}`, + wantErr: true, + }, { + desc: "proto2 enum", + inputMessage: &pb2.Enums{}, + inputText: `{ + "optEnum": "ONE", + "optNestedEnum": "UNO" +}`, + wantMessage: &pb2.Enums{ + OptEnum: pb2.Enum_ONE.Enum(), + OptNestedEnum: pb2.Enums_UNO.Enum(), + }, + }, { + desc: "proto3 enum", + inputMessage: &pb3.Enums{}, + inputText: `{ + "sEnum": "ONE", + "sNestedEnum": "DIEZ" +}`, + wantMessage: &pb3.Enums{ + SEnum: pb3.Enum_ONE, + SNestedEnum: pb3.Enums_DIEZ, + }, + }, { + desc: "enum numeric value", + inputMessage: &pb3.Enums{}, + inputText: `{ + "sEnum": 2, + "sNestedEnum": 2 +}`, + wantMessage: &pb3.Enums{ + SEnum: pb3.Enum_TWO, + SNestedEnum: pb3.Enums_DOS, + }, + }, { + desc: "enum unnamed numeric value", + inputMessage: &pb3.Enums{}, + inputText: `{ + "sEnum": 101, + "sNestedEnum": -101 +}`, + wantMessage: &pb3.Enums{ + SEnum: 101, + SNestedEnum: -101, + }, + }, { + desc: "enum set to number string", + inputMessage: &pb3.Enums{}, + inputText: `{ + "sEnum": "1", +}`, + wantErr: true, + }, { + desc: "enum set to invalid named", + inputMessage: &pb3.Enums{}, + inputText: `{ + "sEnum": "UNNAMED", +}`, + wantErr: true, + }, { + desc: "enum set to not enum", + inputMessage: &pb3.Enums{}, + inputText: `{ + "sEnum": true, +}`, + wantErr: true, + }, { + desc: "proto name", + inputMessage: &pb3.JSONNames{}, + inputText: `{ + "s_string": "proto name used" +}`, + wantMessage: &pb3.JSONNames{ + SString: "proto name used", + }, + }, { + desc: "json_name", + inputMessage: &pb3.JSONNames{}, + inputText: `{ + "foo_bar": "json_name used" +}`, + wantMessage: &pb3.JSONNames{ + SString: "json_name used", + }, + }, { + desc: "camelCase name", + inputMessage: &pb3.JSONNames{}, + inputText: `{ + "sString": "camelcase used" +}`, + wantErr: true, + }, { + desc: "proto name and json_name", + inputMessage: &pb3.JSONNames{}, + inputText: `{ + "foo_bar": "json_name used", + "s_string": "proto name used" +}`, + wantErr: true, + }, { + desc: "duplicate field names", + inputMessage: &pb3.JSONNames{}, + inputText: `{ + "foo_bar": "one", + "foo_bar": "two", +}`, + wantErr: true, + }, { + desc: "null message", + inputMessage: &pb2.Nests{}, + inputText: "null", + wantErr: true, + }, { + desc: "proto2 nested message not set", + inputMessage: &pb2.Nests{}, + inputText: "{}", + wantMessage: &pb2.Nests{}, + }, { + desc: "proto2 nested message set to null", + inputMessage: &pb2.Nests{}, + inputText: `{ + "optNested": null, + "optgroup": null +}`, + wantMessage: &pb2.Nests{}, + }, { + desc: "proto2 nested message set to empty", + inputMessage: &pb2.Nests{}, + inputText: `{ + "optNested": {}, + "optgroup": {} +}`, + wantMessage: &pb2.Nests{ + OptNested: &pb2.Nested{}, + Optgroup: &pb2.Nests_OptGroup{}, + }, + }, { + desc: "proto2 nested messages", + inputMessage: &pb2.Nests{}, + inputText: `{ + "optNested": { + "optString": "nested message", + "optNested": { + "optString": "another nested message" + } + } +}`, + wantMessage: &pb2.Nests{ + OptNested: &pb2.Nested{ + OptString: scalar.String("nested message"), + OptNested: &pb2.Nested{ + OptString: scalar.String("another nested message"), + }, + }, + }, + }, { + desc: "proto2 groups", + inputMessage: &pb2.Nests{}, + inputText: `{ + "optgroup": { + "optString": "inside a group", + "optNested": { + "optString": "nested message inside a group" + }, + "optnestedgroup": { + "optFixed32": 47 + } + } +}`, + wantMessage: &pb2.Nests{ + Optgroup: &pb2.Nests_OptGroup{ + OptString: scalar.String("inside a group"), + OptNested: &pb2.Nested{ + OptString: scalar.String("nested message inside a group"), + }, + Optnestedgroup: &pb2.Nests_OptGroup_OptNestedGroup{ + OptFixed32: scalar.Uint32(47), + }, + }, + }, + }, { + desc: "proto3 nested message not set", + inputMessage: &pb3.Nests{}, + inputText: "{}", + wantMessage: &pb3.Nests{}, + }, { + desc: "proto3 nested message set to null", + inputMessage: &pb3.Nests{}, + inputText: `{"sNested": null}`, + wantMessage: &pb3.Nests{}, + }, { + desc: "proto3 nested message set to empty", + inputMessage: &pb3.Nests{}, + inputText: `{"sNested": {}}`, + wantMessage: &pb3.Nests{ + SNested: &pb3.Nested{}, + }, + }, { + desc: "proto3 nested message", + inputMessage: &pb3.Nests{}, + inputText: `{ + "sNested": { + "sString": "nested message", + "sNested": { + "sString": "another nested message" + } + } +}`, + wantMessage: &pb3.Nests{ + SNested: &pb3.Nested{ + SString: "nested message", + SNested: &pb3.Nested{ + SString: "another nested message", + }, + }, + }, + }, { + desc: "message set to non-message", + inputMessage: &pb3.Nests{}, + inputText: `"not valid"`, + wantErr: true, + }, { + desc: "nested message set to non-message", + inputMessage: &pb3.Nests{}, + inputText: `{"sNested": true}`, + wantErr: true, + }, { + desc: "oneof not set", + inputMessage: &pb3.Oneofs{}, + inputText: "{}", + wantMessage: &pb3.Oneofs{}, + }, { + desc: "oneof set to empty string", + inputMessage: &pb3.Oneofs{}, + inputText: `{"oneofString": ""}`, + wantMessage: &pb3.Oneofs{ + Union: &pb3.Oneofs_OneofString{}, + }, + }, { + desc: "oneof set to string", + inputMessage: &pb3.Oneofs{}, + inputText: `{"oneofString": "hello"}`, + wantMessage: &pb3.Oneofs{ + Union: &pb3.Oneofs_OneofString{ + OneofString: "hello", + }, + }, + }, { + desc: "oneof set to enum", + inputMessage: &pb3.Oneofs{}, + inputText: `{"oneofEnum": "ZERO"}`, + wantMessage: &pb3.Oneofs{ + Union: &pb3.Oneofs_OneofEnum{ + OneofEnum: pb3.Enum_ZERO, + }, + }, + }, { + desc: "oneof set to empty message", + inputMessage: &pb3.Oneofs{}, + inputText: `{"oneofNested": {}}`, + wantMessage: &pb3.Oneofs{ + Union: &pb3.Oneofs_OneofNested{ + OneofNested: &pb3.Nested{}, + }, + }, + }, { + desc: "oneof set to message", + inputMessage: &pb3.Oneofs{}, + inputText: `{ + "oneofNested": { + "sString": "nested message" + } +}`, + wantMessage: &pb3.Oneofs{ + Union: &pb3.Oneofs_OneofNested{ + OneofNested: &pb3.Nested{ + SString: "nested message", + }, + }, + }, + }, { + desc: "repeated null fields", + inputMessage: &pb2.Repeats{}, + inputText: `{ + "rptString": null, + "rptInt32" : null, + "rptFloat" : null, + "rptBytes" : null +}`, + wantMessage: &pb2.Repeats{}, + }, { + desc: "repeated scalars", + inputMessage: &pb2.Repeats{}, + inputText: `{ + "rptString": ["hello", "world"], + "rptInt32" : [-1, 0, 1], + "rptBool" : [false, true] +}`, + wantMessage: &pb2.Repeats{ + RptString: []string{"hello", "world"}, + RptInt32: []int32{-1, 0, 1}, + RptBool: []bool{false, true}, + }, + }, { + desc: "repeated enums", + inputMessage: &pb2.Enums{}, + inputText: `{ + "rptEnum" : ["TEN", 1, 42], + "rptNestedEnum": ["DOS", 2, -47] +}`, + wantMessage: &pb2.Enums{ + RptEnum: []pb2.Enum{pb2.Enum_TEN, pb2.Enum_ONE, 42}, + RptNestedEnum: []pb2.Enums_NestedEnum{pb2.Enums_DOS, pb2.Enums_DOS, -47}, + }, + }, { + desc: "repeated messages", + inputMessage: &pb2.Nests{}, + inputText: `{ + "rptNested": [ + { + "optString": "repeat nested one" + }, + { + "optString": "repeat nested two", + "optNested": { + "optString": "inside repeat nested two" + } + }, + {} + ] +}`, + wantMessage: &pb2.Nests{ + RptNested: []*pb2.Nested{ + { + OptString: scalar.String("repeat nested one"), + }, + { + OptString: scalar.String("repeat nested two"), + OptNested: &pb2.Nested{ + OptString: scalar.String("inside repeat nested two"), + }, + }, + {}, + }, + }, + }, { + desc: "repeated groups", + inputMessage: &pb2.Nests{}, + inputText: `{ + "rptgroup": [ + { + "rptString": ["hello", "world"] + }, + {} + ] +} +`, + wantMessage: &pb2.Nests{ + Rptgroup: []*pb2.Nests_RptGroup{ + { + RptString: []string{"hello", "world"}, + }, + {}, + }, + }, + }, { + desc: "repeated scalars containing invalid type", + inputMessage: &pb2.Repeats{}, + inputText: `{"rptString": ["hello", null, "world"]}`, + wantErr: true, + }, { + desc: "repeated messages containing invalid type", + inputMessage: &pb2.Nests{}, + inputText: `{"rptNested": [{}, null]}`, + wantErr: true, + }, { + desc: "map fields 1", + inputMessage: &pb3.Maps{}, + inputText: `{ + "int32ToStr": { + "-101": "-101", + "0" : "zero", + "255" : "0xff" + }, + "boolToUint32": { + "false": 101, + "true" : "42" + } +}`, + wantMessage: &pb3.Maps{ + Int32ToStr: map[int32]string{ + -101: "-101", + 0xff: "0xff", + 0: "zero", + }, + BoolToUint32: map[bool]uint32{ + true: 42, + false: 101, + }, + }, + }, { + desc: "map fields 2", + inputMessage: &pb3.Maps{}, + inputText: `{ + "uint64ToEnum": { + "1" : "ONE", + "2" : 2, + "10": 101 + } +}`, + wantMessage: &pb3.Maps{ + Uint64ToEnum: map[uint64]pb3.Enum{ + 1: pb3.Enum_ONE, + 2: pb3.Enum_TWO, + 10: 101, + }, + }, + }, { + desc: "map fields 3", + inputMessage: &pb3.Maps{}, + inputText: `{ + "strToNested": { + "nested_one": { + "sString": "nested in a map" + }, + "nested_two": {} + } +}`, + wantMessage: &pb3.Maps{ + StrToNested: map[string]*pb3.Nested{ + "nested_one": { + SString: "nested in a map", + }, + "nested_two": {}, + }, + }, + }, { + desc: "map fields 4", + inputMessage: &pb3.Maps{}, + inputText: `{ + "strToOneofs": { + "nested": { + "oneofNested": { + "sString": "nested oneof in map field value" + } + }, + "string": { + "oneofString": "hello" + } + } +}`, + wantMessage: &pb3.Maps{ + StrToOneofs: map[string]*pb3.Oneofs{ + "string": { + Union: &pb3.Oneofs_OneofString{ + OneofString: "hello", + }, + }, + "nested": { + Union: &pb3.Oneofs_OneofNested{ + OneofNested: &pb3.Nested{ + SString: "nested oneof in map field value", + }, + }, + }, + }, + }, + }, { + desc: "map contains duplicate keys", + inputMessage: &pb3.Maps{}, + inputText: `{ + "int32ToStr": { + "0": "cero", + "0": "zero" + } +} +`, + wantErr: true, + }, { + desc: "map key empty string", + inputMessage: &pb3.Maps{}, + inputText: `{ + "strToNested": { + "": {} + } +}`, + wantMessage: &pb3.Maps{ + StrToNested: map[string]*pb3.Nested{ + "": {}, + }, + }, + }, { + desc: "map contains invalid key 1", + inputMessage: &pb3.Maps{}, + inputText: `{ + "int32ToStr": { + "invalid": "cero" +}`, + wantErr: true, + }, { + desc: "map contains invalid key 2", + inputMessage: &pb3.Maps{}, + inputText: `{ + "int32ToStr": { + "1.02": "float" +}`, + wantErr: true, + }, { + desc: "map contains invalid key 3", + inputMessage: &pb3.Maps{}, + inputText: `{ + "int32ToStr": { + "2147483648": "exceeds 32-bit integer max limit" +}`, + wantErr: true, + }, { + desc: "map contains invalid key 4", + inputMessage: &pb3.Maps{}, + inputText: `{ + "uint64ToEnum": { + "-1": 0 + } +}`, + wantErr: true, + }, { + desc: "map contains invalid value", + inputMessage: &pb3.Maps{}, + inputText: `{ + "int32ToStr": { + "101": true +}`, + wantErr: true, + }, { + desc: "map contains null for scalar value", + inputMessage: &pb3.Maps{}, + inputText: `{ + "int32ToStr": { + "101": null +}`, + wantErr: true, + }, { + desc: "map contains null for message value", + inputMessage: &pb3.Maps{}, + inputText: `{ + "strToNested": { + "hello": null + } +}`, + wantErr: true, + }} + + for _, tt := range tests { + tt := tt + t.Run(tt.desc, func(t *testing.T) { + err := tt.umo.Unmarshal(tt.inputMessage, []byte(tt.inputText)) + if err != nil && !tt.wantErr { + t.Errorf("Unmarshal() returned error: %v\n\n", err) + } + if err == nil && tt.wantErr { + t.Error("Unmarshal() got nil error, want error\n\n") + } + if tt.wantMessage != nil && !protoV1.Equal(tt.inputMessage.(protoV1.Message), tt.wantMessage.(protoV1.Message)) { + t.Errorf("Unmarshal()\n\n%v\n\n%v\n", tt.inputMessage, tt.wantMessage) + } + }) + } +} diff --git a/encoding/jsonpb/doc.go b/encoding/jsonpb/doc.go new file mode 100644 index 00000000..99cf2dd6 --- /dev/null +++ b/encoding/jsonpb/doc.go @@ -0,0 +1,11 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package jsonpb marshals and unmarshals protocol buffer messages as JSON +// format. It follows the guide at +// https://developers.google.com/protocol-buffers/docs/proto3#json. +// +// This package produces a different output than the standard "encoding/json" +// package, which does not operate correctly on protocol buffer messages. +package jsonpb diff --git a/internal/encoding/json/decode.go b/internal/encoding/json/decode.go index 543abbca..452e8736 100644 --- a/internal/encoding/json/decode.go +++ b/internal/encoding/json/decode.go @@ -15,9 +15,24 @@ import ( "github.com/golang/protobuf/v2/internal/errors" ) +// call specifies which Decoder method was invoked. +type call uint8 + +const ( + readCall call = iota + peekCall +) + // Decoder is a token-based JSON decoder. type Decoder struct { - lastType Type + // lastCall is last method called, eiterh readCall or peekCall. + lastCall call + + // value contains the last read value. + value Value + + // err contains the last read error. + err error // startStack is a stack containing StartObject and StartArray types. The // top of stack represents the object or the array the current value is @@ -35,10 +50,24 @@ func NewDecoder(b []byte) *Decoder { return &Decoder{orig: b, in: b} } -// ReadNext returns the next JSON value. It will return an error if there is no -// valid JSON value. For String types containing invalid UTF8 characters, a -// non-fatal error is returned and caller can call ReadNext for the next value. -func (d *Decoder) ReadNext() (Value, error) { +// Peek looks ahead and returns the next JSON type without advancing a read. +func (d *Decoder) Peek() Type { + defer func() { d.lastCall = peekCall }() + if d.lastCall == readCall { + d.value, d.err = d.Read() + } + return d.value.typ +} + +// Read returns the next JSON value. It will return an error if there is no +// valid value. For String types containing invalid UTF8 characters, a +// non-fatal error is returned and caller can call Read for the next value. +func (d *Decoder) Read() (Value, error) { + defer func() { d.lastCall = readCall }() + if d.lastCall == peekCall { + return d.value, d.err + } + var nerr errors.NonFatal value, n, err := d.parseNext() if !nerr.Merge(err) { @@ -48,7 +77,7 @@ func (d *Decoder) ReadNext() (Value, error) { switch value.typ { case EOF: if len(d.startStack) != 0 || - d.lastType&Null|Bool|Number|String|EndObject|EndArray == 0 { + d.value.typ&Null|Bool|Number|String|EndObject|EndArray == 0 { return Value{}, io.ErrUnexpectedEOF } @@ -67,7 +96,7 @@ func (d *Decoder) ReadNext() (Value, error) { break } // Check if this is for an object name. - if d.lastType&(StartObject|comma) == 0 { + if d.value.typ&(StartObject|comma) == 0 { return Value{}, d.newSyntaxError("unexpected value %q", value) } d.in = d.in[n:] @@ -86,7 +115,7 @@ func (d *Decoder) ReadNext() (Value, error) { case EndObject: if len(d.startStack) == 0 || - d.lastType == comma || + d.value.typ == comma || d.startStack[len(d.startStack)-1] != StartObject { return Value{}, d.newSyntaxError("unexpected character }") } @@ -94,7 +123,7 @@ func (d *Decoder) ReadNext() (Value, error) { case EndArray: if len(d.startStack) == 0 || - d.lastType == comma || + d.value.typ == comma || d.startStack[len(d.startStack)-1] != StartArray { return Value{}, d.newSyntaxError("unexpected character ]") } @@ -102,18 +131,18 @@ func (d *Decoder) ReadNext() (Value, error) { case comma: if len(d.startStack) == 0 || - d.lastType&(Null|Bool|Number|String|EndObject|EndArray) == 0 { + d.value.typ&(Null|Bool|Number|String|EndObject|EndArray) == 0 { return Value{}, d.newSyntaxError("unexpected character ,") } } // Update lastType only after validating value to be in the right // sequence. - d.lastType = value.typ + d.value.typ = value.typ d.in = d.in[n:] - if d.lastType == comma { - return d.ReadNext() + if d.value.typ == comma { + return d.Read() } return value, nerr.E } @@ -244,19 +273,19 @@ func (d *Decoder) consume(n int) { // Number, String or Bool. func (d *Decoder) isValueNext() bool { if len(d.startStack) == 0 { - return d.lastType == 0 + return d.value.typ == 0 } start := d.startStack[len(d.startStack)-1] switch start { case StartObject: - return d.lastType&Name != 0 + return d.value.typ&Name != 0 case StartArray: - return d.lastType&(StartArray|comma) != 0 + return d.value.typ&(StartArray|comma) != 0 } panic(fmt.Sprintf( "unreachable logic in Decoder.isValueNext, lastType: %v, startStack: %v", - d.lastType, start)) + d.value.typ, start)) } // newValue constructs a Value. @@ -271,7 +300,7 @@ func (d *Decoder) newValue(typ Type, input []byte, value interface{}) Value { } } -// Value contains a JSON type and value parsed from calling Decoder.ReadNext. +// Value contains a JSON type and value parsed from calling Decoder.Read. type Value struct { input []byte line int diff --git a/internal/encoding/json/decode_test.go b/internal/encoding/json/decode_test.go index 4917bc27..dc7f25fb 100644 --- a/internal/encoding/json/decode_test.go +++ b/internal/encoding/json/decode_test.go @@ -13,9 +13,9 @@ import ( ) type R struct { - // T is expected Type returned from calling Decoder.ReadNext. + // T is expected Type returned from calling Decoder.Read. T json.Type - // E is expected error substring from calling Decoder.ReadNext if set. + // E is expected error substring from calling Decoder.Read if set. E string // V is expected value from calling // Value.{Bool()|Float()|Int()|Uint()|String()} depending on type. @@ -31,8 +31,8 @@ func TestDecoder(t *testing.T) { tests := []struct { input string // want is a list of expected values returned from calling - // Decoder.ReadNext. An item makes the test code invoke - // Decoder.ReadNext and compare against R.T and R.E. For Bool, + // Decoder.Read. An item makes the test code invoke + // Decoder.Read and compare against R.T and R.E. For Bool, // Number and String tokens, it invokes the corresponding getter method // and compares the returned value against R.V or R.VE if it returned an // error. @@ -47,8 +47,8 @@ func TestDecoder(t *testing.T) { want: []R{{T: json.EOF}}, }, { - // Calling ReadNext after EOF will keep returning EOF for - // succeeding ReadNext calls. + // Calling Read after EOF will keep returning EOF for + // succeeding Read calls. input: space, want: []R{ {T: json.EOF}, @@ -119,7 +119,7 @@ func TestDecoder(t *testing.T) { }, }, { - // Invalid UTF-8 error is returned in ReadString instead of ReadNext. + // Invalid UTF-8 error is returned in ReadString instead of Read. input: "\"\xff\"", want: []R{ {T: json.String, E: `invalid UTF-8 detected`, V: string("\xff")}, @@ -1009,22 +1009,26 @@ func TestDecoder(t *testing.T) { t.Run("", func(t *testing.T) { dec := json.NewDecoder([]byte(tc.input)) for i, want := range tc.want { - value, err := dec.ReadNext() + typ := dec.Peek() + if typ != want.T { + t.Errorf("input: %v\nPeek() got %v want %v", tc.input, typ, want.T) + } + value, err := dec.Read() if err != nil { if want.E == "" { - t.Errorf("input: %v\nReadNext() got unexpected error: %v", tc.input, err) + t.Errorf("input: %v\nRead() got unexpected error: %v", tc.input, err) } else if !strings.Contains(err.Error(), want.E) { - t.Errorf("input: %v\nReadNext() got %q, want %q", tc.input, err, want.E) + t.Errorf("input: %v\nRead() got %q, want %q", tc.input, err, want.E) } } else { if want.E != "" { - t.Errorf("input: %v\nReadNext() got nil error, want %q", tc.input, want.E) + t.Errorf("input: %v\nRead() got nil error, want %q", tc.input, want.E) } } token := value.Type() if token != want.T { - t.Errorf("input: %v\nReadNext() got %v, want %v", tc.input, token, want.T) + t.Errorf("input: %v\nRead() got %v, want %v", tc.input, token, want.T) break } checkValue(t, value, i, want) diff --git a/internal/encoding/json/types.go b/internal/encoding/json/types.go index 28901e82..35feeb7a 100644 --- a/internal/encoding/json/types.go +++ b/internal/encoding/json/types.go @@ -8,7 +8,7 @@ package json type Type uint const ( - _ Type = (1 << iota) / 2 + Invalid Type = (1 << iota) / 2 EOF Null Bool