internal/detectknown: add helper package to identify well-known types

Change-Id: Id54621b4b44522a350e6994074962852690b5d66
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/225257
Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
Joe Tsai 2020-03-24 11:46:34 -07:00
parent f8d77f810a
commit d037755d51
4 changed files with 177 additions and 97 deletions

View File

@ -11,6 +11,7 @@ import (
"strings"
"time"
"google.golang.org/protobuf/internal/detectknown"
"google.golang.org/protobuf/internal/encoding/json"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/fieldnum"
@ -23,27 +24,18 @@ import (
// The list of custom types here has to match the ones in marshalCustomType and
// unmarshalCustomType.
func isCustomType(name pref.FullName) bool {
switch name {
case "google.protobuf.Any",
"google.protobuf.BoolValue",
"google.protobuf.DoubleValue",
"google.protobuf.FloatValue",
"google.protobuf.Int32Value",
"google.protobuf.Int64Value",
"google.protobuf.UInt32Value",
"google.protobuf.UInt64Value",
"google.protobuf.StringValue",
"google.protobuf.BytesValue",
"google.protobuf.Empty",
"google.protobuf.Struct",
"google.protobuf.ListValue",
"google.protobuf.Value",
"google.protobuf.Duration",
"google.protobuf.Timestamp",
"google.protobuf.FieldMask":
return true
switch detectknown.Which(name) {
case detectknown.AnyProto:
case detectknown.TimestampProto:
case detectknown.DurationProto:
case detectknown.WrappersProto:
case detectknown.StructProto:
case detectknown.FieldMaskProto:
case detectknown.EmptyProto:
default:
return false
}
return false
return true
}
// marshalCustomType marshals given well-known type message that have special
@ -51,44 +43,24 @@ func isCustomType(name pref.FullName) bool {
// returns true, else it will panic.
func (e encoder) marshalCustomType(m pref.Message) error {
name := m.Descriptor().FullName()
switch name {
case "google.protobuf.Any":
switch detectknown.Which(name) {
case detectknown.AnyProto:
return e.marshalAny(m)
case "google.protobuf.BoolValue",
"google.protobuf.DoubleValue",
"google.protobuf.FloatValue",
"google.protobuf.Int32Value",
"google.protobuf.Int64Value",
"google.protobuf.UInt32Value",
"google.protobuf.UInt64Value",
"google.protobuf.StringValue",
"google.protobuf.BytesValue":
return e.marshalWrapperType(m)
case "google.protobuf.Empty":
return e.marshalEmpty(m)
case "google.protobuf.Struct":
return e.marshalStruct(m)
case "google.protobuf.ListValue":
return e.marshalListValue(m)
case "google.protobuf.Value":
return e.marshalKnownValue(m)
case "google.protobuf.Duration":
return e.marshalDuration(m)
case "google.protobuf.Timestamp":
case detectknown.TimestampProto:
return e.marshalTimestamp(m)
case "google.protobuf.FieldMask":
case detectknown.DurationProto:
return e.marshalDuration(m)
case detectknown.WrappersProto:
return e.marshalWrapperType(m)
case detectknown.StructProto:
return e.marshalStructType(m)
case detectknown.FieldMaskProto:
return e.marshalFieldMask(m)
case detectknown.EmptyProto:
return e.marshalEmpty(m)
default:
panic(fmt.Sprintf("%s does not have a custom marshaler", name))
}
panic(fmt.Sprintf("%s does not have a custom marshaler", name))
}
// unmarshalCustomType unmarshals given well-known type message that have
@ -96,44 +68,24 @@ func (e encoder) marshalCustomType(m pref.Message) error {
// isCustomType returns true, else it will panic.
func (d decoder) unmarshalCustomType(m pref.Message) error {
name := m.Descriptor().FullName()
switch name {
case "google.protobuf.Any":
switch detectknown.Which(name) {
case detectknown.AnyProto:
return d.unmarshalAny(m)
case "google.protobuf.BoolValue",
"google.protobuf.DoubleValue",
"google.protobuf.FloatValue",
"google.protobuf.Int32Value",
"google.protobuf.Int64Value",
"google.protobuf.UInt32Value",
"google.protobuf.UInt64Value",
"google.protobuf.StringValue",
"google.protobuf.BytesValue":
return d.unmarshalWrapperType(m)
case "google.protobuf.Empty":
return d.unmarshalEmpty(m)
case "google.protobuf.Struct":
return d.unmarshalStruct(m)
case "google.protobuf.ListValue":
return d.unmarshalListValue(m)
case "google.protobuf.Value":
return d.unmarshalKnownValue(m)
case "google.protobuf.Duration":
return d.unmarshalDuration(m)
case "google.protobuf.Timestamp":
case detectknown.TimestampProto:
return d.unmarshalTimestamp(m)
case "google.protobuf.FieldMask":
case detectknown.DurationProto:
return d.unmarshalDuration(m)
case detectknown.WrappersProto:
return d.unmarshalWrapperType(m)
case detectknown.StructProto:
return d.unmarshalStructType(m)
case detectknown.FieldMaskProto:
return d.unmarshalFieldMask(m)
case detectknown.EmptyProto:
return d.unmarshalEmpty(m)
default:
panic(fmt.Sprintf("%s does not have a custom unmarshaler", name))
}
panic(fmt.Sprintf("%s does not have a custom unmarshaler", name))
}
// The JSON representation of an Any message uses the regular representation of
@ -501,6 +453,32 @@ func (d decoder) unmarshalEmpty(pref.Message) error {
}
}
func (e encoder) marshalStructType(m pref.Message) error {
switch m.Descriptor().Name() {
case "Struct":
return e.marshalStruct(m)
case "ListValue":
return e.marshalListValue(m)
case "Value":
return e.marshalKnownValue(m)
default:
panic(fmt.Sprintf("invalid struct type: %v", m.Descriptor().FullName()))
}
}
func (d decoder) unmarshalStructType(m pref.Message) error {
switch m.Descriptor().Name() {
case "Struct":
return d.unmarshalStruct(m)
case "ListValue":
return d.unmarshalListValue(m)
case "Value":
return d.unmarshalKnownValue(m)
default:
panic(fmt.Sprintf("invalid struct type: %v", m.Descriptor().FullName()))
}
}
// The JSON representation for Struct is a JSON object that contains the encoded
// Struct.fields map and follows the serialization rules for a map.

View File

@ -0,0 +1,47 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package detectknown provides functionality for detecting well-known types
// and identifying them by name.
package detectknown
import "google.golang.org/protobuf/reflect/protoreflect"
type ProtoFile int
const (
Unknown ProtoFile = iota
AnyProto
TimestampProto
DurationProto
WrappersProto
StructProto
FieldMaskProto
EmptyProto
)
var wellKnownTypes = map[protoreflect.FullName]ProtoFile{
"google.protobuf.Any": AnyProto,
"google.protobuf.Timestamp": TimestampProto,
"google.protobuf.Duration": DurationProto,
"google.protobuf.BoolValue": WrappersProto,
"google.protobuf.Int32Value": WrappersProto,
"google.protobuf.Int64Value": WrappersProto,
"google.protobuf.UInt32Value": WrappersProto,
"google.protobuf.UInt64Value": WrappersProto,
"google.protobuf.FloatValue": WrappersProto,
"google.protobuf.DoubleValue": WrappersProto,
"google.protobuf.BytesValue": WrappersProto,
"google.protobuf.StringValue": WrappersProto,
"google.protobuf.Struct": StructProto,
"google.protobuf.ListValue": StructProto,
"google.protobuf.Value": StructProto,
"google.protobuf.FieldMask": FieldMaskProto,
"google.protobuf.Empty": EmptyProto,
}
// Which identifies the proto file that a well-known type belongs to.
func Which(s protoreflect.FullName) ProtoFile {
return wellKnownTypes[s]
}

View File

@ -0,0 +1,58 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package detectknown_test
import (
"testing"
"google.golang.org/protobuf/internal/detectknown"
"google.golang.org/protobuf/reflect/protoreflect"
fieldmaskpb "google.golang.org/protobuf/internal/testprotos/fieldmaskpb"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
"google.golang.org/protobuf/types/pluginpb"
)
func TestWhich(t *testing.T) {
tests := []struct {
in protoreflect.FileDescriptor
want detectknown.ProtoFile
}{
{descriptorpb.File_google_protobuf_descriptor_proto, detectknown.Unknown},
{pluginpb.File_google_protobuf_compiler_plugin_proto, detectknown.Unknown},
{anypb.File_google_protobuf_any_proto, detectknown.AnyProto},
{timestamppb.File_google_protobuf_timestamp_proto, detectknown.TimestampProto},
{durationpb.File_google_protobuf_duration_proto, detectknown.DurationProto},
{wrapperspb.File_google_protobuf_wrappers_proto, detectknown.WrappersProto},
{structpb.File_google_protobuf_struct_proto, detectknown.StructProto},
{fieldmaskpb.File_google_protobuf_field_mask_proto, detectknown.FieldMaskProto},
{emptypb.File_google_protobuf_empty_proto, detectknown.EmptyProto},
}
for _, tt := range tests {
rangeMessages(tt.in.Messages(), func(md protoreflect.MessageDescriptor) {
got := detectknown.Which(md.FullName())
if got != tt.want {
t.Errorf("Which(%s) = %v, want %v", md.FullName(), got, tt.want)
}
})
}
}
func rangeMessages(mds protoreflect.MessageDescriptors, f func(protoreflect.MessageDescriptor)) {
for i := 0; i < mds.Len(); i++ {
md := mds.Get(i)
if !md.IsMapEntry() {
f(md)
}
rangeMessages(md.Messages(), f)
}
}

View File

@ -18,6 +18,7 @@ import (
"time"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/internal/detectknown"
"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/proto"
@ -102,13 +103,9 @@ var protocmpMessageType = reflect.TypeOf(map[string]interface{}(nil))
func appendKnownMessage(b []byte, m protoreflect.Message) []byte {
md := m.Descriptor()
if md.FullName().Parent() != "google.protobuf" {
return nil
}
fds := md.Fields()
switch md.Name() {
case "Any":
switch detectknown.Which(md.FullName()) {
case detectknown.AnyProto:
var msgVal protoreflect.Message
url := m.Get(fds.ByName("type_url")).String()
if v := reflect.ValueOf(m); v.Type().ConvertibleTo(protocmpMessageType) {
@ -140,7 +137,7 @@ func appendKnownMessage(b []byte, m protoreflect.Message) []byte {
b = append(b, '}')
return b
case "Timestamp":
case detectknown.TimestampProto:
secs := m.Get(fds.ByName("seconds")).Int()
nanos := m.Get(fds.ByName("nanos")).Int()
if nanos < 0 || nanos >= 1e9 {
@ -153,7 +150,7 @@ func appendKnownMessage(b []byte, m protoreflect.Message) []byte {
x = strings.TrimSuffix(x, ".000")
return append(b, x+"Z"...)
case "Duration":
case detectknown.DurationProto:
secs := m.Get(fds.ByName("seconds")).Int()
nanos := m.Get(fds.ByName("nanos")).Int()
if nanos <= -1e9 || nanos >= 1e9 || (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) {
@ -165,7 +162,7 @@ func appendKnownMessage(b []byte, m protoreflect.Message) []byte {
x = strings.TrimSuffix(x, ".000")
return append(b, x+"s"...)
case "BoolValue", "Int32Value", "Int64Value", "UInt32Value", "UInt64Value", "FloatValue", "DoubleValue", "StringValue", "BytesValue":
case detectknown.WrappersProto:
fd := fds.ByName("value")
return appendValue(b, m.Get(fd), fd)
}