diff --git a/encoding/jsonpb/decode.go b/encoding/jsonpb/decode.go index a43efd12..5ea6fb85 100644 --- a/encoding/jsonpb/decode.go +++ b/encoding/jsonpb/decode.go @@ -13,9 +13,11 @@ import ( "github.com/golang/protobuf/v2/internal/encoding/json" "github.com/golang/protobuf/v2/internal/errors" + "github.com/golang/protobuf/v2/internal/pragma" "github.com/golang/protobuf/v2/internal/set" "github.com/golang/protobuf/v2/proto" pref "github.com/golang/protobuf/v2/reflect/protoreflect" + "github.com/golang/protobuf/v2/reflect/protoregistry" ) // Unmarshal reads the given []byte into the given proto.Message. @@ -24,7 +26,14 @@ func Unmarshal(m proto.Message, b []byte) error { } // UnmarshalOptions is a configurable JSON format parser. -type UnmarshalOptions struct{} +type UnmarshalOptions struct { + pragma.NoUnkeyedLiterals + + // Resolver is the registry used for type lookups when unmarshaling extensions + // and processing Any. If Resolver is not set, unmarshaling will default to + // using protoregistry.GlobalTypes. + Resolver *protoregistry.Types +} // Unmarshal reads the given []byte and populates the given proto.Message using // options in UnmarshalOptions object. It will clear the message first before @@ -37,7 +46,15 @@ func (o UnmarshalOptions) Unmarshal(m proto.Message, b []byte) error { // marshaling. resetMessage(mr) - dec := decoder{json.NewDecoder(b)} + resolver := o.Resolver + if resolver == nil { + resolver = protoregistry.GlobalTypes + } + + dec := decoder{ + Decoder: json.NewDecoder(b), + resolver: resolver, + } var nerr errors.NonFatal if err := dec.unmarshalMessage(mr); !nerr.Merge(err) { return err @@ -108,6 +125,7 @@ func newError(f string, x ...interface{}) error { // decoder decodes JSON into protoreflect values. type decoder struct { *json.Decoder + resolver *protoregistry.Types } // unmarshalMessage unmarshals a message into the given protoreflect.Message. @@ -119,6 +137,7 @@ func (d decoder) unmarshalMessage(m pref.Message) error { msgType := m.Type() knownFields := m.KnownFields() fieldDescs := msgType.Fields() + xtTypes := knownFields.ExtensionTypes() jval, err := d.Read() if !nerr.Merge(err) { @@ -149,11 +168,28 @@ Loop: return err } - // Get the FieldDescriptor based on the field name. The name can either - // be the JSON name for the field or the proto field name. - fd := fieldDescs.ByJSONName(name) - if fd == nil { - fd = fieldDescs.ByName(pref.Name(name)) + // Get the FieldDescriptor. + var fd pref.FieldDescriptor + if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") { + // Only extension names are in [name] format. + xtName := pref.FullName(name[1 : len(name)-1]) + xt := xtTypes.ByName(xtName) + if xt == nil { + xt, err = d.findExtension(xtName) + if err != nil && err != protoregistry.NotFound { + return errors.New("unable to resolve [%v]: %v", xtName, err) + } + if xt != nil { + xtTypes.Register(xt) + } + } + fd = xt + } else { + // The name can either be the JSON name or the proto field name. + fd = fieldDescs.ByJSONName(name) + if fd == nil { + fd = fieldDescs.ByName(pref.Name(name)) + } } if fd == nil { @@ -204,6 +240,21 @@ Loop: return nerr.E } +// findExtension returns protoreflect.ExtensionType from the resolver if found. +func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) { + xt, err := d.resolver.FindExtensionByName(xtName) + if err == nil { + return xt, nil + } + + // Check if this is a MessageSet extension field. + xt, err = d.resolver.FindExtensionByName(xtName + ".message_set_extension") + if err == nil && isMessageSetExtension(xt) { + return xt, nil + } + return nil, protoregistry.NotFound +} + // unmarshalSingular unmarshals to the non-repeated field specified by the given // FieldDescriptor. func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, knownFields pref.KnownFields) error { @@ -294,8 +345,8 @@ func unmarshalInt(jval json.Value, bitSize int) (pref.Value, error) { return getInt(jval, bitSize) case json.String: - // Use another decoder to decode number from string. - dec := decoder{json.NewDecoder([]byte(jval.String()))} + // Decode number from string. + dec := json.NewDecoder([]byte(jval.String())) var nerr errors.NonFatal jval, err := dec.Read() if !nerr.Merge(err) { @@ -323,8 +374,8 @@ func unmarshalUint(jval json.Value, bitSize int) (pref.Value, error) { return getUint(jval, bitSize) case json.String: - // Use another decoder to decode number from string. - dec := decoder{json.NewDecoder([]byte(jval.String()))} + // Decode number from string. + dec := json.NewDecoder([]byte(jval.String())) var nerr errors.NonFatal jval, err := dec.Read() if !nerr.Merge(err) { @@ -370,8 +421,8 @@ func unmarshalFloat(jval json.Value, bitSize int) (pref.Value, error) { } return pref.ValueOf(math.Inf(-1)), nil } - // Use another decoder to decode number from string. - dec := decoder{json.NewDecoder([]byte(s))} + // Decode number from string. + dec := json.NewDecoder([]byte(s)) var nerr errors.NonFatal jval, err := dec.Read() if !nerr.Merge(err) { diff --git a/encoding/jsonpb/decode_test.go b/encoding/jsonpb/decode_test.go index 97a0619c..2ef81cd5 100644 --- a/encoding/jsonpb/decode_test.go +++ b/encoding/jsonpb/decode_test.go @@ -9,13 +9,43 @@ import ( "testing" protoV1 "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/protoapi" "github.com/golang/protobuf/v2/encoding/jsonpb" "github.com/golang/protobuf/v2/encoding/testprotos/pb2" "github.com/golang/protobuf/v2/encoding/testprotos/pb3" "github.com/golang/protobuf/v2/internal/scalar" "github.com/golang/protobuf/v2/proto" + preg "github.com/golang/protobuf/v2/reflect/protoregistry" ) +func init() { + // TODO: remove these registerExtension calls when generated code registers + // to V2 global registry. + registerExtension(pb2.E_OptExtBool) + registerExtension(pb2.E_OptExtString) + registerExtension(pb2.E_OptExtEnum) + registerExtension(pb2.E_OptExtNested) + registerExtension(pb2.E_RptExtFixed32) + registerExtension(pb2.E_RptExtEnum) + registerExtension(pb2.E_RptExtNested) + registerExtension(pb2.E_ExtensionsContainer_OptExtBool) + registerExtension(pb2.E_ExtensionsContainer_OptExtString) + registerExtension(pb2.E_ExtensionsContainer_OptExtEnum) + registerExtension(pb2.E_ExtensionsContainer_OptExtNested) + registerExtension(pb2.E_ExtensionsContainer_RptExtString) + registerExtension(pb2.E_ExtensionsContainer_RptExtEnum) + registerExtension(pb2.E_ExtensionsContainer_RptExtNested) + registerExtension(pb2.E_MessageSetExtension) + registerExtension(pb2.E_MessageSetExtension_MessageSetExtension) + registerExtension(pb2.E_MessageSetExtension_NotMessageSetExtension) + registerExtension(pb2.E_MessageSetExtension_ExtNested) + registerExtension(pb2.E_FakeMessageSetExtension_MessageSetExtension) +} + +func registerExtension(xd *protoapi.ExtensionDesc) { + preg.GlobalTypes.Register(xd.Type) +} + func TestUnmarshal(t *testing.T) { tests := []struct { desc string @@ -907,6 +937,215 @@ func TestUnmarshal(t *testing.T) { } }`, wantErr: true, + }, { + desc: "extensions of non-repeated fields", + inputMessage: &pb2.Extensions{}, + inputText: `{ + "optString": "non-extension field", + "optBool": true, + "optInt32": 42, + "[pb2.opt_ext_bool]": true, + "[pb2.opt_ext_nested]": { + "optString": "nested in an extension", + "opt_nested": { + "opt_string": "another nested in an extension" + } + }, + "[pb2.opt_ext_string]": "extension field", + "[pb2.opt_ext_enum]": "TEN" +}`, + wantMessage: func() proto.Message { + m := &pb2.Extensions{ + OptString: scalar.String("non-extension field"), + OptBool: scalar.Bool(true), + OptInt32: scalar.Int32(42), + } + setExtension(m, pb2.E_OptExtBool, true) + setExtension(m, pb2.E_OptExtString, "extension field") + setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) + setExtension(m, pb2.E_OptExtNested, &pb2.Nested{ + OptString: scalar.String("nested in an extension"), + OptNested: &pb2.Nested{ + OptString: scalar.String("another nested in an extension"), + }, + }) + return m + }(), + }, { + desc: "extensions of repeated fields", + inputMessage: &pb2.Extensions{}, + inputText: `{ + "[pb2.rpt_ext_enum]": ["TEN", 101, "ONE"], + "[pb2.rpt_ext_fixed32]": [42, 47], + "[pb2.rpt_ext_nested]": [ + {"optString": "one"}, + {"optString": "two"}, + {"optString": "three"} + ] +}`, + wantMessage: func() proto.Message { + m := &pb2.Extensions{} + setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) + setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ + &pb2.Nested{OptString: scalar.String("one")}, + &pb2.Nested{OptString: scalar.String("two")}, + &pb2.Nested{OptString: scalar.String("three")}, + }) + return m + }(), + }, { + desc: "extensions of non-repeated fields in another message", + inputMessage: &pb2.Extensions{}, + inputText: `{ + "[pb2.ExtensionsContainer.opt_ext_bool]": true, + "[pb2.ExtensionsContainer.opt_ext_enum]": "TEN", + "[pb2.ExtensionsContainer.opt_ext_nested]": { + "optString": "nested in an extension", + "optNested": { + "optString": "another nested in an extension" + } + }, + "[pb2.ExtensionsContainer.opt_ext_string]": "extension field" +}`, + wantMessage: func() proto.Message { + m := &pb2.Extensions{} + setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) + setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") + setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) + setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ + OptString: scalar.String("nested in an extension"), + OptNested: &pb2.Nested{ + OptString: scalar.String("another nested in an extension"), + }, + }) + return m + }(), + }, { + desc: "extensions of repeated fields in another message", + inputMessage: &pb2.Extensions{}, + inputText: `{ + "optString": "non-extension field", + "optBool": true, + "optInt32": 42, + "[pb2.ExtensionsContainer.rpt_ext_nested]": [ + {"optString": "one"}, + {"optString": "two"}, + {"optString": "three"} + ], + "[pb2.ExtensionsContainer.rpt_ext_enum]": ["TEN", 101, "ONE"], + "[pb2.ExtensionsContainer.rpt_ext_string]": ["hello", "world"] +}`, + wantMessage: func() proto.Message { + m := &pb2.Extensions{ + OptString: scalar.String("non-extension field"), + OptBool: scalar.Bool(true), + OptInt32: scalar.Int32(42), + } + setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) + setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ + &pb2.Nested{OptString: scalar.String("one")}, + &pb2.Nested{OptString: scalar.String("two")}, + &pb2.Nested{OptString: scalar.String("three")}, + }) + return m + }(), + }, { + desc: "invalid extension field name", + inputMessage: &pb2.Extensions{}, + inputText: `{ "[pb2.invalid_message_field]": true }`, + wantErr: true, + }, { + desc: "MessageSet", + inputMessage: &pb2.MessageSet{}, + inputText: `{ + "[pb2.MessageSetExtension]": { + "optString": "a messageset extension" + }, + "[pb2.MessageSetExtension.ext_nested]": { + "optString": "just a regular extension" + }, + "[pb2.MessageSetExtension.not_message_set_extension]": { + "optString": "not a messageset extension" + } +}`, + wantMessage: func() proto.Message { + m := &pb2.MessageSet{} + setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ + OptString: scalar.String("a messageset extension"), + }) + setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ + OptString: scalar.String("not a messageset extension"), + }) + setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ + OptString: scalar.String("just a regular extension"), + }) + return m + }(), + }, { + desc: "extension field set to null", + inputMessage: &pb2.Extensions{}, + inputText: `{ + "[pb2.ExtensionsContainer.opt_ext_bool]": null, + "[pb2.ExtensionsContainer.opt_ext_nested]": null +}`, + wantMessage: func() proto.Message { + m := &pb2.Extensions{} + setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, nil) + setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, nil) + return m + }(), + }, { + desc: "extensions of repeated field contains null", + inputMessage: &pb2.Extensions{}, + inputText: `{ + "[pb2.ExtensionsContainer.rpt_ext_nested]": [ + {"optString": "one"}, + null, + {"optString": "three"} + ], +}`, + wantErr: true, + }, { + desc: "not real MessageSet 1", + inputMessage: &pb2.FakeMessageSet{}, + inputText: `{ + "[pb2.FakeMessageSetExtension.message_set_extension]": { + "optString": "not a messageset extension" + } +}`, + wantMessage: func() proto.Message { + m := &pb2.FakeMessageSet{} + setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ + OptString: scalar.String("not a messageset extension"), + }) + return m + }(), + }, { + desc: "not real MessageSet 2", + inputMessage: &pb2.FakeMessageSet{}, + inputText: `{ + "[pb2.FakeMessageSetExtension]": { + "optString": "not a messageset extension" + } +}`, + wantErr: true, + }, { + desc: "not real MessageSet 3", + inputMessage: &pb2.MessageSet{}, + inputText: `{ + "[pb2.message_set_extension]": { + "optString": "another not a messageset extension" + } +}`, + wantMessage: func() proto.Message { + m := &pb2.MessageSet{} + setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ + OptString: scalar.String("another not a messageset extension"), + }) + return m + }(), }} for _, tt := range tests {