encoding/jsonpb: add support for marshaling of extensions and messagesets

Change-Id: I839660146760a66c5cbf25d24f81f0ba5096d9e1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167395
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Herbie Ong 2019-03-14 00:01:27 -07:00
parent 99f24c33b0
commit f83d5bb6f0
2 changed files with 308 additions and 16 deletions

View File

@ -13,6 +13,8 @@ import (
"github.com/golang/protobuf/v2/internal/pragma"
"github.com/golang/protobuf/v2/proto"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
descpb "github.com/golang/protobuf/v2/types/descriptor"
)
// Marshal writes the given proto.Message in JSON format using default options.
@ -70,6 +72,7 @@ func (e encoder) marshalMessage(m pref.Message) error {
fieldDescs := m.Type().Fields()
knownFields := m.KnownFields()
// Marshal out known fields.
for i := 0; i < fieldDescs.Len(); i++ {
fd := fieldDescs.Get(i)
num := fd.Number()
@ -92,6 +95,11 @@ func (e encoder) marshalMessage(m pref.Message) error {
return err
}
}
// Marshal out extensions.
if err := e.marshalExtensions(knownFields); !nerr.Merge(err) {
return err
}
return nerr.E
}
@ -222,7 +230,6 @@ func (e encoder) marshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
if err := e.WriteName(entry.key.String()); !nerr.Merge(err) {
return err
}
if err := e.marshalSingular(entry.value, valType); !nerr.Merge(err) {
return err
}
@ -230,22 +237,94 @@ func (e encoder) marshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
return nerr.E
}
// sortMap orders list based on value of key field for deterministic output.
// sortMap orders list based on value of key field for deterministic ordering.
func sortMap(keyKind pref.Kind, values []mapEntry) {
less := func(i, j int) bool {
return values[i].key.String() < values[j].key.String()
}
switch keyKind {
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
less = func(i, j int) bool {
sort.Slice(values, func(i, j int) bool {
switch keyKind {
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
return values[i].key.Int() < values[j].key.Int()
}
case pref.Uint32Kind, pref.Fixed32Kind,
pref.Uint64Kind, pref.Fixed64Kind:
less = func(i, j int) bool {
case pref.Uint32Kind, pref.Fixed32Kind,
pref.Uint64Kind, pref.Fixed64Kind:
return values[i].key.Uint() < values[j].key.Uint()
}
}
sort.Slice(values, less)
return values[i].key.String() < values[j].key.String()
})
}
// marshalExtensions marshals extension fields.
func (e encoder) marshalExtensions(knownFields pref.KnownFields) error {
type xtEntry struct {
key string
value pref.Value
xtType pref.ExtensionType
}
xtTypes := knownFields.ExtensionTypes()
// Get a sorted list based on field key first.
entries := make([]xtEntry, 0, xtTypes.Len())
xtTypes.Range(func(xt pref.ExtensionType) bool {
name := xt.FullName()
// If extended type is a MessageSet, set field name to be the message type name.
if isMessageSetExtension(xt) {
name = xt.MessageType().FullName()
}
num := xt.Number()
if knownFields.Has(num) {
// Use [name] format for JSON field name.
pval := knownFields.Get(num)
entries = append(entries, xtEntry{
key: string(name),
value: pval,
xtType: xt,
})
}
return true
})
// Sort extensions lexicographically.
sort.Slice(entries, func(i, j int) bool {
return entries[i].key < entries[j].key
})
// Write out sorted list.
var nerr errors.NonFatal
for _, entry := range entries {
// JSON field name is the proto field name enclosed in [], similar to
// textproto. This is consistent with Go v1 lib. C++ lib v3.7.0 does not
// marshal out extension fields.
if err := e.WriteName("[" + entry.key + "]"); !nerr.Merge(err) {
return err
}
if err := e.marshalValue(entry.value, entry.xtType); !nerr.Merge(err) {
return err
}
}
return nerr.E
}
// isMessageSetExtension reports whether extension extends a message set.
func isMessageSetExtension(xt pref.ExtensionType) bool {
if xt.Name() != "message_set_extension" {
return false
}
mt := xt.MessageType()
if mt == nil {
return false
}
if xt.FullName().Parent() != mt.FullName() {
return false
}
xmt := xt.ExtendedType()
if xmt.Fields().Len() != 0 {
return false
}
opt := xmt.Options().(*descpb.MessageOptions)
if opt == nil {
return false
}
return opt.GetMessageSetWireFormat()
}

View File

@ -9,13 +9,18 @@ import (
"strings"
"testing"
"github.com/golang/protobuf/protoapi"
"github.com/golang/protobuf/v2/encoding/jsonpb"
"github.com/golang/protobuf/v2/internal/encoding/pack"
"github.com/golang/protobuf/v2/internal/encoding/wire"
"github.com/golang/protobuf/v2/internal/scalar"
"github.com/golang/protobuf/v2/proto"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
// This legacy package is still needed when importing legacy message.
_ "github.com/golang/protobuf/v2/internal/legacy"
"github.com/golang/protobuf/v2/encoding/testprotos/pb2"
"github.com/golang/protobuf/v2/encoding/testprotos/pb3"
)
@ -37,6 +42,17 @@ func pb2Enums_NestedEnum(i int32) *pb2.Enums_NestedEnum {
return p
}
func setExtension(m proto.Message, xd *protoapi.ExtensionDesc, val interface{}) {
knownFields := m.ProtoReflect().KnownFields()
extTypes := knownFields.ExtensionTypes()
extTypes.Register(xd.Type)
if val == nil {
return
}
pval := xd.Type.ValueOf(val)
knownFields.Set(wire.Number(xd.Field), pval)
}
func TestMarshal(t *testing.T) {
tests := []struct {
desc string
@ -700,13 +716,210 @@ func TestMarshal(t *testing.T) {
},
want: `{
"foo_bar": "json_name"
}`,
}, {
desc: "extensions of non-repeated fields",
input: func() proto.Message {
m := &pb2.Extensions{
OptString: scalar.String("non-extension field"),
OptBool: scalar.Bool(true),
OptInt32: scalar.Int32(42),
}
setExtension(m, pb2.E_OptExtBool, true)
setExtension(m, pb2.E_OptExtString, "extension field")
setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
OptString: scalar.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: scalar.String("another nested in an extension"),
},
})
return m
}(),
want: `{
"optString": "non-extension field",
"optBool": true,
"optInt32": 42,
"[pb2.opt_ext_bool]": true,
"[pb2.opt_ext_enum]": "TEN",
"[pb2.opt_ext_nested]": {
"optString": "nested in an extension",
"optNested": {
"optString": "another nested in an extension"
}
},
"[pb2.opt_ext_string]": "extension field"
}`,
}, {
desc: "extension message field set to nil",
input: func() proto.Message {
m := &pb2.Extensions{}
setExtension(m, pb2.E_OptExtNested, nil)
return m
}(),
want: "{}",
}, {
desc: "extensions of repeated fields",
input: func() proto.Message {
m := &pb2.Extensions{}
setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: scalar.String("one")},
&pb2.Nested{OptString: scalar.String("two")},
&pb2.Nested{OptString: scalar.String("three")},
})
return m
}(),
want: `{
"[pb2.rpt_ext_enum]": [
"TEN",
101,
"ONE"
],
"[pb2.rpt_ext_fixed32]": [
42,
47
],
"[pb2.rpt_ext_nested]": [
{
"optString": "one"
},
{
"optString": "two"
},
{
"optString": "three"
}
]
}`,
}, {
desc: "extensions of non-repeated fields in another message",
input: func() proto.Message {
m := &pb2.Extensions{}
setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
OptString: scalar.String("nested in an extension"),
OptNested: &pb2.Nested{
OptString: scalar.String("another nested in an extension"),
},
})
return m
}(),
want: `{
"[pb2.ExtensionsContainer.opt_ext_bool]": true,
"[pb2.ExtensionsContainer.opt_ext_enum]": "TEN",
"[pb2.ExtensionsContainer.opt_ext_nested]": {
"optString": "nested in an extension",
"optNested": {
"optString": "another nested in an extension"
}
},
"[pb2.ExtensionsContainer.opt_ext_string]": "extension field"
}`,
}, {
desc: "extensions of repeated fields in another message",
input: func() proto.Message {
m := &pb2.Extensions{
OptString: scalar.String("non-extension field"),
OptBool: scalar.Bool(true),
OptInt32: scalar.Int32(42),
}
setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
&pb2.Nested{OptString: scalar.String("one")},
&pb2.Nested{OptString: scalar.String("two")},
&pb2.Nested{OptString: scalar.String("three")},
})
return m
}(),
want: `{
"optString": "non-extension field",
"optBool": true,
"optInt32": 42,
"[pb2.ExtensionsContainer.rpt_ext_enum]": [
"TEN",
101,
"ONE"
],
"[pb2.ExtensionsContainer.rpt_ext_nested]": [
{
"optString": "one"
},
{
"optString": "two"
},
{
"optString": "three"
}
],
"[pb2.ExtensionsContainer.rpt_ext_string]": [
"hello",
"world"
]
}`,
}, {
desc: "MessageSet",
input: func() proto.Message {
m := &pb2.MessageSet{}
setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
OptString: scalar.String("a messageset extension"),
})
setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
OptString: scalar.String("not a messageset extension"),
})
setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
OptString: scalar.String("just a regular extension"),
})
return m
}(),
want: `{
"[pb2.MessageSetExtension]": {
"optString": "a messageset extension"
},
"[pb2.MessageSetExtension.ext_nested]": {
"optString": "just a regular extension"
},
"[pb2.MessageSetExtension.not_message_set_extension]": {
"optString": "not a messageset extension"
}
}`,
}, {
desc: "not real MessageSet 1",
input: func() proto.Message {
m := &pb2.FakeMessageSet{}
setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: scalar.String("not a messageset extension"),
})
return m
}(),
want: `{
"[pb2.FakeMessageSetExtension.message_set_extension]": {
"optString": "not a messageset extension"
}
}`,
}, {
desc: "not real MessageSet 2",
input: func() proto.Message {
m := &pb2.MessageSet{}
setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
OptString: scalar.String("another not a messageset extension"),
})
return m
}(),
want: `{
"[pb2.message_set_extension]": {
"optString": "another not a messageset extension"
}
}`,
}}
for _, tt := range tests {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
t.Parallel()
b, err := tt.mo.Marshal(tt.input)
if err != nil {
t.Errorf("Marshal() returned error: %v\n", err)