internal/impl: support legacy extension fields

Implement support for extension fields for messages that use the v1
data structures for extensions. The legacyExtensionFields type wraps a
v1 map to implement the v2 protoreflect.KnownFields interface.

Working on this change revealed a bug in the dynamic construction of
message types for protobuf messages that had cyclic dependencies (e.g.,
message Foo has a sub-field of message Bar, and Bar has a sub-field of Foo).
In such a situation, a deadlock occurs because initialization code depends on
the very initialization code that is currently running. To break these cycles,
we make some systematic changes listed in the following paragraphs.
Generally speaking, we separate the logic for construction and wrapping,
where constuction does not recursively rely on dependencies,
while wrapping may recursively inspect dependencies.

Promote the MessageType.MessageOf method as a standalone MessageOf function
that dynamically finds the proper *MessageType to use. We make it such that
MessageType only supports two forms of messages types:
* Those that fully implement the v2 API.
* Those that do not implement the v2 API at all.
This removes support for the hybrid form that was exploited by message_test.go

In impl/message_test.go, switch each message to look more like how future
generated messages will look like. This is done in reaction to the fact that
MessageType.MessageOf no longer exists.

In value/{map,vector}.go, fix Unwrap to return a pointer since the underlying
reflect.Value is addressable reference value, not a pointer value.

In value/convert.go, split the logic apart so that obtaining a v2 type and
wrapping a type as v2 are distinct operations. Wrapping requires further
initialization than simply creating the initial message type, and calling it
during initial construction would lead to a deadlock.

In protoreflect/go_type.go, we switch back to a lazy initialization of GoType
to avoid a deadlock since the user-provided fn may rely on the fact that
prototype.GoMessage returned.

Change-Id: I5dea00e36fe1a9899bd2ac0aed2c8e51d5d87420
Reviewed-on: https://go-review.googlesource.com/c/148826
Reviewed-by: Herbie Ong <herbie@google.com>
This commit is contained in:
Joe Tsai 2018-11-06 13:05:20 -08:00 committed by Joe Tsai
parent ea11813c05
commit f0c01e459b
15 changed files with 1137 additions and 346 deletions

View File

@ -16,29 +16,36 @@ import (
ptype "github.com/golang/protobuf/v2/reflect/prototype"
)
// legacyWrapEnum wraps v as a protoreflect.ProtoEnum,
// where v must be a *struct kind and not implement the v2 API already.
func legacyWrapEnum(v reflect.Value) pref.ProtoEnum {
et := legacyLoadEnumType(v.Type())
return et.New(pref.EnumNumber(v.Int()))
}
var enumTypeCache sync.Map // map[reflect.Type]protoreflect.EnumType
// wrapLegacyEnum wraps v as a protoreflect.ProtoEnum,
// where v must be an int32 kind and not implement the v2 API already.
func wrapLegacyEnum(v reflect.Value) pref.ProtoEnum {
// legacyLoadEnumType dynamically loads a protoreflect.EnumType for t,
// where t must be an int32 kind and not implement the v2 API already.
func legacyLoadEnumType(t reflect.Type) pref.EnumType {
// Fast-path: check if a EnumType is cached for this concrete type.
if et, ok := enumTypeCache.Load(v.Type()); ok {
return et.(pref.EnumType).New(pref.EnumNumber(v.Int()))
if et, ok := enumTypeCache.Load(t); ok {
return et.(pref.EnumType)
}
// Slow-path: derive enum descriptor and initialize EnumType.
var m sync.Map // map[protoreflect.EnumNumber]proto.Enum
ed := loadEnumDesc(v.Type())
ed := legacyLoadEnumDesc(t)
et := ptype.GoEnum(ed, func(et pref.EnumType, n pref.EnumNumber) pref.ProtoEnum {
if e, ok := m.Load(n); ok {
return e.(pref.ProtoEnum)
}
e := &legacyEnumWrapper{num: n, pbTyp: et, goTyp: v.Type()}
e := &legacyEnumWrapper{num: n, pbTyp: et, goTyp: t}
m.Store(n, e)
return e
})
enumTypeCache.Store(v.Type(), et)
return et.(pref.EnumType).New(pref.EnumNumber(v.Int()))
enumTypeCache.Store(t, et)
return et.(pref.EnumType)
}
type legacyEnumWrapper struct {
@ -70,9 +77,11 @@ var (
var enumDescCache sync.Map // map[reflect.Type]protoreflect.EnumDescriptor
// loadEnumDesc returns an EnumDescriptor derived from the Go type,
var enumNumberType = reflect.TypeOf(pref.EnumNumber(0))
// legacyLoadEnumDesc returns an EnumDescriptor derived from the Go type,
// which must be an int32 kind and not implement the v2 API already.
func loadEnumDesc(t reflect.Type) pref.EnumDescriptor {
func legacyLoadEnumDesc(t reflect.Type) pref.EnumDescriptor {
// Fast-path: check if an EnumDescriptor is cached for this concrete type.
if v, ok := enumDescCache.Load(t); ok {
return v.(pref.EnumDescriptor)
@ -82,6 +91,9 @@ func loadEnumDesc(t reflect.Type) pref.EnumDescriptor {
if t.Kind() != reflect.Int32 || t.PkgPath() == "" {
panic(fmt.Sprintf("got %v, want named int32 kind", t))
}
if t == enumNumberType {
panic(fmt.Sprintf("cannot be %v", t))
}
// Derive the enum descriptor from the raw descriptor proto.
e := new(ptype.StandaloneEnum)
@ -91,7 +103,7 @@ func loadEnumDesc(t reflect.Type) pref.EnumDescriptor {
}
if ed, ok := ev.(legacyEnum); ok {
b, idxs := ed.EnumDescriptor()
fd := loadFileDesc(b)
fd := legacyLoadFileDesc(b)
// Derive syntax.
switch fd.GetSyntax() {

View File

@ -0,0 +1,353 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"fmt"
"reflect"
protoV1 "github.com/golang/protobuf/proto"
ptag "github.com/golang/protobuf/v2/internal/encoding/tag"
pvalue "github.com/golang/protobuf/v2/internal/value"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
ptype "github.com/golang/protobuf/v2/reflect/prototype"
)
func makeLegacyExtensionFieldsFunc(t reflect.Type) func(p *messageDataType) pref.KnownFields {
f := makeLegacyExtensionMapFunc(t)
if f == nil {
return nil
}
return func(p *messageDataType) pref.KnownFields {
return legacyExtensionFields{p.mi, f(p)}
}
}
type legacyExtensionFields struct {
mi *MessageType
x legacyExtensionIface
}
func (p legacyExtensionFields) Len() (n int) {
p.x.Range(func(num pref.FieldNumber, _ legacyExtensionEntry) bool {
if p.Has(num) {
n++
}
return true
})
return n
}
func (p legacyExtensionFields) Has(n pref.FieldNumber) bool {
x := p.x.Get(n)
if x.val == nil {
return false
}
t := legacyExtensionTypeOf(x.desc)
if t.Cardinality() == pref.Repeated {
return legacyExtensionValueOf(x.val, t).Vector().Len() > 0
}
return true
}
func (p legacyExtensionFields) Get(n pref.FieldNumber) pref.Value {
x := p.x.Get(n)
if x.desc == nil {
return pref.Value{}
}
t := legacyExtensionTypeOf(x.desc)
if x.val == nil {
if t.Cardinality() == pref.Repeated {
// TODO: What is the zero value for Vectors?
// TODO: This logic is racy.
v := t.ValueOf(t.New())
x.val = legacyExtensionInterfaceOf(v, t)
p.x.Set(n, x)
return v
}
if t.Kind() == pref.MessageKind || t.Kind() == pref.GroupKind {
// TODO: What is the zero value for Messages?
return pref.Value{}
}
return t.Default()
}
return legacyExtensionValueOf(x.val, t)
}
func (p legacyExtensionFields) Set(n pref.FieldNumber, v pref.Value) {
x := p.x.Get(n)
if x.desc == nil {
panic("no extension descriptor registered")
}
t := legacyExtensionTypeOf(x.desc)
x.val = legacyExtensionInterfaceOf(v, t)
p.x.Set(n, x)
}
func (p legacyExtensionFields) Clear(n pref.FieldNumber) {
x := p.x.Get(n)
if x.desc == nil {
return
}
x.val = nil
p.x.Set(n, x)
}
func (p legacyExtensionFields) Mutable(n pref.FieldNumber) pref.Mutable {
x := p.x.Get(n)
if x.desc == nil {
panic("no extension descriptor registered")
}
t := legacyExtensionTypeOf(x.desc)
if x.val == nil {
v := t.ValueOf(t.New())
x.val = legacyExtensionInterfaceOf(v, t)
p.x.Set(n, x)
}
return legacyExtensionValueOf(x.val, t).Interface().(pref.Mutable)
}
func (p legacyExtensionFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
p.x.Range(func(n pref.FieldNumber, x legacyExtensionEntry) bool {
if p.Has(n) {
return f(n, p.Get(n))
}
return true
})
}
func (p legacyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes {
return legacyExtensionTypes(p)
}
type legacyExtensionTypes legacyExtensionFields
func (p legacyExtensionTypes) Len() (n int) {
p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
if x.desc != nil {
n++
}
return true
})
return n
}
func (p legacyExtensionTypes) Register(t pref.ExtensionType) {
if p.mi.Type.FullName() != t.ExtendedType().FullName() {
panic("extended type mismatch")
}
if !p.mi.Type.ExtensionRanges().Has(t.Number()) {
panic("invalid extension field number")
}
x := p.x.Get(t.Number())
if x.desc != nil {
panic("extension descriptor already registered")
}
x.desc = legacyExtensionDescOf(t, p.mi.goType)
p.x.Set(t.Number(), x)
}
func (p legacyExtensionTypes) Remove(t pref.ExtensionType) {
if !p.mi.Type.ExtensionRanges().Has(t.Number()) {
return
}
x := p.x.Get(t.Number())
if x.val != nil {
panic("value for extension descriptor still populated")
}
x.desc = nil
if len(x.raw) == 0 {
p.x.Clear(t.Number())
} else {
p.x.Set(t.Number(), x)
}
}
func (p legacyExtensionTypes) ByNumber(n pref.FieldNumber) pref.ExtensionType {
x := p.x.Get(n)
if x.desc != nil {
return legacyExtensionTypeOf(x.desc)
}
return nil
}
func (p legacyExtensionTypes) ByName(s pref.FullName) (t pref.ExtensionType) {
p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
if x.desc != nil && x.desc.Name == string(s) {
t = legacyExtensionTypeOf(x.desc)
return false
}
return true
})
return t
}
func (p legacyExtensionTypes) Range(f func(pref.ExtensionType) bool) {
p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
if x.desc != nil {
if !f(legacyExtensionTypeOf(x.desc)) {
return false
}
}
return true
})
}
func legacyExtensionDescOf(t pref.ExtensionType, parent reflect.Type) *protoV1.ExtensionDesc {
if t, ok := t.(*legacyExtensionType); ok {
return t.desc
}
// Determine the v1 extension type, which is unfortunately not the same as
// the v2 ExtensionType.GoType.
extType := t.GoType()
switch extType.Kind() {
case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
extType = reflect.PtrTo(extType) // T -> *T for singular scalar fields
case reflect.Ptr:
if extType.Elem().Kind() == reflect.Slice {
extType = extType.Elem() // *[]T -> []T for repeated fields
}
}
// Reconstruct the legacy enum full name, which is an odd mixture of the
// proto package name with the Go type name.
var enumName string
if t.Kind() == pref.EnumKind {
enumName = t.GoType().Name()
for d, ok := pref.Descriptor(t.EnumType()), true; ok; d, ok = d.Parent() {
if fd, _ := d.(pref.FileDescriptor); fd != nil && fd.Package() != "" {
enumName = string(fd.Package()) + "." + enumName
}
}
}
// Construct and return a v1 ExtensionDesc.
return &protoV1.ExtensionDesc{
ExtendedType: reflect.Zero(parent).Interface().(protoV1.Message),
ExtensionType: reflect.Zero(extType).Interface(),
Field: int32(t.Number()),
Name: string(t.FullName()),
Tag: ptag.Marshal(t, enumName),
}
}
func legacyExtensionTypeOf(d *protoV1.ExtensionDesc) pref.ExtensionType {
// TODO: Add a field to protoV1.ExtensionDesc to contain a v2 descriptor.
// Derive basic field information from the struct tag.
t := reflect.TypeOf(d.ExtensionType)
isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
if isOptional || isRepeated {
t = t.Elem()
}
f := ptag.Unmarshal(d.Tag, t)
// Construct a v2 ExtensionType.
conv := newConverter(t, f.Kind)
xd, err := ptype.NewExtension(&ptype.StandaloneExtension{
FullName: pref.FullName(d.Name),
Number: pref.FieldNumber(d.Field),
Cardinality: f.Cardinality,
Kind: f.Kind,
Default: f.Default,
Options: f.Options,
EnumType: conv.EnumType,
MessageType: conv.MessageType,
ExtendedType: legacyLoadMessageDesc(reflect.TypeOf(d.ExtendedType)),
})
if err != nil {
panic(err)
}
xt := ptype.GoExtension(xd, conv.EnumType, conv.MessageType)
// Return the extension type as is if the dependencies already support v2.
xt2 := &legacyExtensionType{ExtensionType: xt, desc: d}
if !conv.IsLegacy {
return xt2
}
// If the dependency is a v1 enum or message, we need to create a custom
// extension type where ExtensionType.GoType continues to use the legacy
// v1 Go type, instead of the wrapped versions that satisfy the v2 API.
if xd.Cardinality() != pref.Repeated {
// Custom extension type for singular enums and messages.
// The legacy wrappers use legacyEnumWrapper and legacyMessageWrapper
// to implement the v2 interfaces for enums and messages.
// Both of those type satisfy the value.Unwrapper interface.
xt2.typ = t
xt2.new = func() interface{} {
return xt.New().(pvalue.Unwrapper).Unwrap()
}
xt2.valueOf = func(v interface{}) pref.Value {
if reflect.TypeOf(v) != xt2.typ {
panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
}
if xd.Kind() == pref.EnumKind {
return xt.ValueOf(legacyWrapEnum(reflect.ValueOf(v)))
} else {
return xt.ValueOf(legacyWrapMessage(reflect.ValueOf(v)))
}
}
xt2.interfaceOf = func(v pref.Value) interface{} {
return xt.InterfaceOf(v).(pvalue.Unwrapper).Unwrap()
}
} else {
// Custom extension type for repeated enums and messages.
xt2.typ = reflect.PtrTo(reflect.SliceOf(t))
xt2.new = func() interface{} {
return reflect.New(xt2.typ.Elem()).Interface()
}
xt2.valueOf = func(v interface{}) pref.Value {
if reflect.TypeOf(v) != xt2.typ {
panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
}
return pref.ValueOf(pvalue.VectorOf(v, conv))
}
xt2.interfaceOf = func(pv pref.Value) interface{} {
v := pv.Vector().(pvalue.Unwrapper).Unwrap()
if reflect.TypeOf(v) != xt2.typ {
panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
}
return v
}
}
return xt2
}
type legacyExtensionType struct {
pref.ExtensionType
desc *protoV1.ExtensionDesc
typ reflect.Type
new func() interface{}
valueOf func(interface{}) pref.Value
interfaceOf func(pref.Value) interface{}
}
func (x *legacyExtensionType) GoType() reflect.Type {
if x.typ != nil {
return x.typ
}
return x.ExtensionType.GoType()
}
func (x *legacyExtensionType) New() interface{} {
if x.new != nil {
return x.new()
}
return x.ExtensionType.New()
}
func (x *legacyExtensionType) ValueOf(v interface{}) pref.Value {
if x.valueOf != nil {
return x.valueOf(v)
}
return x.ExtensionType.ValueOf(v)
}
func (x *legacyExtensionType) InterfaceOf(v pref.Value) interface{} {
if x.interfaceOf != nil {
return x.interfaceOf(v)
}
return x.ExtensionType.InterfaceOf(v)
}

View File

@ -13,10 +13,13 @@ import (
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
)
// TODO: The logic below this is a hack since v1 currently exposes no
// exported functionality for interacting with these data structures.
// Eventually make changes to v1 such that v2 can access the necessary
// fields without relying on unsafe.
// TODO: The logic in the file is a hack and should be in the v1 repository.
// We need to break the dependency on proto v1 since it is v1 that will
// eventually need to depend on v2.
// TODO: The v1 API currently exposes no exported functionality for interacting
// with the extension data structures. We will need to make changes in v1 so
// that v2 can access these data structures without relying on unsafe.
var (
extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil))
@ -25,8 +28,10 @@ var (
type legacyExtensionIface interface {
Len() int
Has(pref.FieldNumber) bool
Get(pref.FieldNumber) legacyExtensionEntry
Set(pref.FieldNumber, legacyExtensionEntry)
Clear(pref.FieldNumber)
Range(f func(pref.FieldNumber, legacyExtensionEntry) bool)
}
@ -49,6 +54,13 @@ func makeLegacyExtensionMapFunc(t reflect.Type) func(*messageDataType) legacyExt
}
}
// TODO: We currently don't do locking with legacyExtensionSyncMap.p.mu.
// The locking behavior was already obscure "feature" beforehand,
// and it is not obvious how it translates to the v2 API.
// The v2 API presents a Range method, which calls a user provided function,
// which may in turn call other methods on the map. In such a use case,
// acquiring a lock within each method would result in a reentrant deadlock.
// legacyExtensionSyncMap is identical to protoV1.XXX_InternalExtensions.
// It implements legacyExtensionIface.
type legacyExtensionSyncMap struct {
@ -62,16 +74,15 @@ func (m legacyExtensionSyncMap) Len() int {
if m.p == nil {
return 0
}
m.p.mu.Lock()
defer m.p.mu.Unlock()
return m.p.m.Len()
}
func (m legacyExtensionSyncMap) Has(n pref.FieldNumber) bool {
return m.p.m.Has(n)
}
func (m legacyExtensionSyncMap) Get(n pref.FieldNumber) legacyExtensionEntry {
if m.p == nil {
return legacyExtensionEntry{}
}
m.p.mu.Lock()
defer m.p.mu.Unlock()
return m.p.m.Get(n)
}
func (m *legacyExtensionSyncMap) Set(n pref.FieldNumber, x legacyExtensionEntry) {
@ -81,16 +92,15 @@ func (m *legacyExtensionSyncMap) Set(n pref.FieldNumber, x legacyExtensionEntry)
m legacyExtensionMap
})
}
m.p.mu.Lock()
defer m.p.mu.Unlock()
m.p.m.Set(n, x)
}
func (m legacyExtensionSyncMap) Clear(n pref.FieldNumber) {
m.p.m.Clear(n)
}
func (m legacyExtensionSyncMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) {
if m.p == nil {
return
}
m.p.mu.Lock()
defer m.p.mu.Unlock()
m.p.m.Range(f)
}
@ -101,6 +111,10 @@ type legacyExtensionMap map[pref.FieldNumber]legacyExtensionEntry
func (m legacyExtensionMap) Len() int {
return len(m)
}
func (m legacyExtensionMap) Has(n pref.FieldNumber) bool {
_, ok := m[n]
return ok
}
func (m legacyExtensionMap) Get(n pref.FieldNumber) legacyExtensionEntry {
return m[n]
}
@ -110,6 +124,9 @@ func (m *legacyExtensionMap) Set(n pref.FieldNumber, x legacyExtensionEntry) {
}
(*m)[n] = x
}
func (m *legacyExtensionMap) Clear(n pref.FieldNumber) {
delete(*m, n)
}
func (m legacyExtensionMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) {
for n, x := range m {
if !f(n, x) {
@ -124,3 +141,76 @@ type legacyExtensionEntry struct {
val interface{}
raw []byte
}
// TODO: The legacyExtensionInterfaceOf and legacyExtensionValueOf converters
// exist since the current storage representation in the v1 data structures use
// *T for scalars and []T for repeated fields, but the v2 API operates on
// T for scalars and *[]T for repeated fields.
//
// Instead of maintaining this technical debt in the v2 repository,
// we can offload this into the v1 implementation such that it uses a
// storage representation that is appropriate for v2, and uses the these
// functions to present the illusion that that the underlying storage
// is still *T and []T.
//
// See https://github.com/golang/protobuf/pull/746
const hasPR746 = true
// legacyExtensionInterfaceOf converts a protoreflect.Value to the
// storage representation used in v1 extension data structures.
//
// In particular, it represents scalars (except []byte) a pointer to the value,
// and repeated fields as the a slice value itself.
func legacyExtensionInterfaceOf(pv pref.Value, t pref.ExtensionType) interface{} {
v := t.InterfaceOf(pv)
if !hasPR746 {
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
// Represent primitive types as a pointer to the value.
rv2 := reflect.New(rv.Type())
rv2.Elem().Set(rv)
v = rv2.Interface()
case reflect.Ptr:
// Represent pointer to slice types as the value itself.
switch rv.Type().Elem().Kind() {
case reflect.Slice:
if rv.IsNil() {
v = reflect.Zero(rv.Type().Elem()).Interface()
} else {
v = rv.Elem().Interface()
}
}
}
}
return v
}
// legacyExtensionValueOf converts the storage representation of a value in
// the v1 extension data structures to a protoreflect.Value.
//
// In particular, it represents scalars as the value itself,
// and repeated fields as a pointer to the slice value.
func legacyExtensionValueOf(v interface{}, t pref.ExtensionType) pref.Value {
if !hasPR746 {
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Ptr:
// Represent slice types as the value itself.
switch rv.Type().Elem().Kind() {
case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
if rv.IsNil() {
v = reflect.Zero(rv.Type().Elem()).Interface()
} else {
v = rv.Elem().Interface()
}
}
case reflect.Slice:
// Represent slice types (except []byte) as a pointer to the value.
if rv.Type().Elem().Kind() != reflect.Uint8 {
rv2 := reflect.New(rv.Type())
rv2.Elem().Set(rv)
v = rv2.Interface()
}
}
}
return t.ValueOf(v)
}

View File

@ -36,13 +36,13 @@ type (
var fileDescCache sync.Map // map[*byte]*descriptorV1.FileDescriptorProto
// loadFileDesc unmarshals b as a compressed FileDescriptorProto message.
// legacyLoadFileDesc unmarshals b as a compressed FileDescriptorProto message.
//
// This assumes that b is immutable and that b does not refer to part of a
// concatenated series of GZIP files (which would require shenanigans that
// rely on the concatenation properties of both protobufs and GZIP).
// File descriptors generated by protoc-gen-go do not rely on that property.
func loadFileDesc(b []byte) *descriptorV1.FileDescriptorProto {
func legacyLoadFileDesc(b []byte) *descriptorV1.FileDescriptorProto {
// Fast-path: check whether we already have a cached file descriptor.
if v, ok := fileDescCache.Load(&b[0]); ok {
return v.(*descriptorV1.FileDescriptorProto)

View File

@ -19,26 +19,38 @@ import (
ptype "github.com/golang/protobuf/v2/reflect/prototype"
)
// legacyWrapMessage wraps v as a protoreflect.ProtoMessage,
// where v must be a *struct kind and not implement the v2 API already.
func legacyWrapMessage(v reflect.Value) pref.ProtoMessage {
mt := legacyLoadMessageType(v.Type())
return (*legacyMessageWrapper)(mt.dataTypeOf(v.Interface()))
}
var messageTypeCache sync.Map // map[reflect.Type]*MessageType
// wrapLegacyMessage wraps v as a protoreflect.ProtoMessage,
// where v must be a *struct kind and not implement the v2 API already.
func wrapLegacyMessage(v reflect.Value) pref.ProtoMessage {
// legacyLoadMessageType dynamically loads a *MessageType for t,
// where t must be a *struct kind and not implement the v2 API already.
func legacyLoadMessageType(t reflect.Type) *MessageType {
// Fast-path: check if a MessageType is cached for this concrete type.
if mt, ok := messageTypeCache.Load(v.Type()); ok {
return mt.(*MessageType).MessageOf(v.Interface()).Interface()
if mt, ok := messageTypeCache.Load(t); ok {
return mt.(*MessageType)
}
// Slow-path: derive message descriptor and initialize MessageType.
mt := &MessageType{Desc: loadMessageDesc(v.Type())}
messageTypeCache.Store(v.Type(), mt)
return mt.MessageOf(v.Interface()).Interface()
md := legacyLoadMessageDesc(t)
mt := new(MessageType)
mt.Type = ptype.GoMessage(md, func(pref.MessageType) pref.ProtoMessage {
p := reflect.New(t.Elem()).Interface()
return (*legacyMessageWrapper)(mt.dataTypeOf(p))
})
messageTypeCache.Store(t, mt)
return mt
}
type legacyMessageWrapper messageDataType
func (m *legacyMessageWrapper) Type() pref.MessageType {
return m.mi.pbType
return m.mi.Type
}
func (m *legacyMessageWrapper) KnownFields() pref.KnownFields {
return (*knownFields)(m)
@ -65,9 +77,9 @@ var (
var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
// loadMessageDesc returns an MessageDescriptor derived from the Go type,
// legacyLoadMessageDesc returns an MessageDescriptor derived from the Go type,
// which must be a *struct kind and not implement the v2 API already.
func loadMessageDesc(t reflect.Type) pref.MessageDescriptor {
func legacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
return messageDescSet{}.Load(t)
}
@ -126,7 +138,7 @@ func (ms *messageDescSet) processMessage(t reflect.Type) pref.MessageDescriptor
}
if md, ok := mv.(legacyMessage); ok {
b, idxs := md.Descriptor()
fd := loadFileDesc(b)
fd := legacyLoadFileDesc(b)
// Derive syntax.
switch fd.GetSyntax() {
@ -231,7 +243,7 @@ func (ms *messageDescSet) parseField(tag, tagKey, tagVal string, goType reflect.
if ev, ok := reflect.Zero(t).Interface().(pref.ProtoEnum); ok {
f.EnumType = ev.ProtoReflect().Type()
} else {
f.EnumType = loadEnumDesc(t)
f.EnumType = legacyLoadEnumDesc(t)
}
}
if f.MessageType == nil && (f.Kind == pref.MessageKind || f.Kind == pref.GroupKind) {

View File

@ -30,7 +30,7 @@ import (
)
func mustLoadFileDesc(b []byte, _ []int) pref.FileDescriptor {
fd, err := ptype.NewFileFromDescriptorProto(loadFileDesc(b), nil)
fd, err := ptype.NewFileFromDescriptorProto(legacyLoadFileDesc(b), nil)
if err != nil {
panic(err)
}
@ -42,286 +42,286 @@ func TestLegacyDescriptor(t *testing.T) {
fileDescP2_20160225 := mustLoadFileDesc(new(proto2_20160225.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto2_20160225.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20160225.SiblingEnum(0))),
want: fileDescP2_20160225.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto2_20160225.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20160225.Message_ChildEnum(0))),
want: fileDescP2_20160225.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.SiblingMessage))),
want: fileDescP2_20160225.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ChildMessage))),
want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message))),
want: fileDescP2_20160225.Messages().ByName("Message"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_NamedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_NamedGroup))),
want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("NamedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_OptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_OptionalGroup))),
want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("OptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_RequiredGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_RequiredGroup))),
want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("RequiredGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_RepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_RepeatedGroup))),
want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("RepeatedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_OneofGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_OneofGroup))),
want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("OneofGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ExtensionOptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ExtensionOptionalGroup))),
want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ExtensionRepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ExtensionRepeatedGroup))),
want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"),
}}...)
fileDescP3_20160225 := mustLoadFileDesc(new(proto3_20160225.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto3_20160225.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20160225.SiblingEnum(0))),
want: fileDescP3_20160225.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto3_20160225.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20160225.Message_ChildEnum(0))),
want: fileDescP3_20160225.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20160225.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160225.SiblingMessage))),
want: fileDescP3_20160225.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20160225.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160225.Message_ChildMessage))),
want: fileDescP3_20160225.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20160225.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160225.Message))),
want: fileDescP3_20160225.Messages().ByName("Message"),
}}...)
fileDescP2_20160519 := mustLoadFileDesc(new(proto2_20160519.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto2_20160519.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20160519.SiblingEnum(0))),
want: fileDescP2_20160519.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto2_20160519.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20160519.Message_ChildEnum(0))),
want: fileDescP2_20160519.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.SiblingMessage))),
want: fileDescP2_20160519.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ChildMessage))),
want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message))),
want: fileDescP2_20160519.Messages().ByName("Message"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_NamedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_NamedGroup))),
want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("NamedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_OptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_OptionalGroup))),
want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("OptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_RequiredGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_RequiredGroup))),
want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("RequiredGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_RepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_RepeatedGroup))),
want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("RepeatedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_OneofGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_OneofGroup))),
want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("OneofGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ExtensionOptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ExtensionOptionalGroup))),
want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ExtensionRepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ExtensionRepeatedGroup))),
want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"),
}}...)
fileDescP3_20160519 := mustLoadFileDesc(new(proto3_20160519.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto3_20160519.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20160519.SiblingEnum(0))),
want: fileDescP3_20160519.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto3_20160519.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20160519.Message_ChildEnum(0))),
want: fileDescP3_20160519.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20160519.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160519.SiblingMessage))),
want: fileDescP3_20160519.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20160519.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160519.Message_ChildMessage))),
want: fileDescP3_20160519.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20160519.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160519.Message))),
want: fileDescP3_20160519.Messages().ByName("Message"),
}}...)
fileDescP2_20180125 := mustLoadFileDesc(new(proto2_20180125.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto2_20180125.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180125.SiblingEnum(0))),
want: fileDescP2_20180125.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto2_20180125.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180125.Message_ChildEnum(0))),
want: fileDescP2_20180125.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.SiblingMessage))),
want: fileDescP2_20180125.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ChildMessage))),
want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message))),
want: fileDescP2_20180125.Messages().ByName("Message"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_NamedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_NamedGroup))),
want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("NamedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_OptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_OptionalGroup))),
want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("OptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_RequiredGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_RequiredGroup))),
want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("RequiredGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_RepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_RepeatedGroup))),
want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("RepeatedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_OneofGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_OneofGroup))),
want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("OneofGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ExtensionOptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ExtensionOptionalGroup))),
want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ExtensionRepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ExtensionRepeatedGroup))),
want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"),
}}...)
fileDescP3_20180125 := mustLoadFileDesc(new(proto3_20180125.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto3_20180125.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180125.SiblingEnum(0))),
want: fileDescP3_20180125.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto3_20180125.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180125.Message_ChildEnum(0))),
want: fileDescP3_20180125.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180125.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180125.SiblingMessage))),
want: fileDescP3_20180125.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180125.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180125.Message_ChildMessage))),
want: fileDescP3_20180125.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180125.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180125.Message))),
want: fileDescP3_20180125.Messages().ByName("Message"),
}}...)
fileDescP2_20180430 := mustLoadFileDesc(new(proto2_20180430.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto2_20180430.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180430.SiblingEnum(0))),
want: fileDescP2_20180430.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto2_20180430.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180430.Message_ChildEnum(0))),
want: fileDescP2_20180430.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.SiblingMessage))),
want: fileDescP2_20180430.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ChildMessage))),
want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message))),
want: fileDescP2_20180430.Messages().ByName("Message"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_NamedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_NamedGroup))),
want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("NamedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_OptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_OptionalGroup))),
want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("OptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_RequiredGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_RequiredGroup))),
want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("RequiredGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_RepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_RepeatedGroup))),
want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("RepeatedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_OneofGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_OneofGroup))),
want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("OneofGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ExtensionOptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ExtensionOptionalGroup))),
want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ExtensionRepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ExtensionRepeatedGroup))),
want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"),
}}...)
fileDescP3_20180430 := mustLoadFileDesc(new(proto3_20180430.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto3_20180430.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180430.SiblingEnum(0))),
want: fileDescP3_20180430.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto3_20180430.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180430.Message_ChildEnum(0))),
want: fileDescP3_20180430.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180430.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180430.SiblingMessage))),
want: fileDescP3_20180430.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180430.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180430.Message_ChildMessage))),
want: fileDescP3_20180430.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180430.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180430.Message))),
want: fileDescP3_20180430.Messages().ByName("Message"),
}}...)
fileDescP2_20180814 := mustLoadFileDesc(new(proto2_20180814.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto2_20180814.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180814.SiblingEnum(0))),
want: fileDescP2_20180814.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto2_20180814.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180814.Message_ChildEnum(0))),
want: fileDescP2_20180814.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.SiblingMessage))),
want: fileDescP2_20180814.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ChildMessage))),
want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message))),
want: fileDescP2_20180814.Messages().ByName("Message"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_NamedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_NamedGroup))),
want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("NamedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_OptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_OptionalGroup))),
want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("OptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_RequiredGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_RequiredGroup))),
want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("RequiredGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_RepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_RepeatedGroup))),
want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("RepeatedGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_OneofGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_OneofGroup))),
want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("OneofGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ExtensionOptionalGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ExtensionOptionalGroup))),
want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ExtensionRepeatedGroup))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ExtensionRepeatedGroup))),
want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"),
}}...)
fileDescP3_20180814 := mustLoadFileDesc(new(proto3_20180814.Message).Descriptor())
tests = append(tests, []struct{ got, want pref.Descriptor }{{
got: loadEnumDesc(reflect.TypeOf(proto3_20180814.SiblingEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180814.SiblingEnum(0))),
want: fileDescP3_20180814.Enums().ByName("SiblingEnum"),
}, {
got: loadEnumDesc(reflect.TypeOf(proto3_20180814.Message_ChildEnum(0))),
got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180814.Message_ChildEnum(0))),
want: fileDescP3_20180814.Messages().ByName("Message").Enums().ByName("ChildEnum"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180814.SiblingMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180814.SiblingMessage))),
want: fileDescP3_20180814.Messages().ByName("SiblingMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180814.Message_ChildMessage))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180814.Message_ChildMessage))),
want: fileDescP3_20180814.Messages().ByName("Message").Messages().ByName("ChildMessage"),
}, {
got: loadMessageDesc(reflect.TypeOf(new(proto3_20180814.Message))),
got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180814.Message))),
want: fileDescP3_20180814.Messages().ByName("Message"),
}}...)
@ -383,13 +383,16 @@ func TestLegacyDescriptor(t *testing.T) {
}
}
type legacyUnknownMessage struct {
type legacyTestMessage struct {
XXX_unrecognized []byte
protoV1.XXX_InternalExtensions
}
func (*legacyUnknownMessage) ExtensionRangeArray() []protoV1.ExtensionRange {
return []protoV1.ExtensionRange{{Start: 10, End: 20}, {Start: 40, End: 80}}
func (*legacyTestMessage) Reset() {}
func (*legacyTestMessage) String() string { return "" }
func (*legacyTestMessage) ProtoMessage() {}
func (*legacyTestMessage) ExtensionRangeArray() []protoV1.ExtensionRange {
return []protoV1.ExtensionRange{{Start: 10, End: 20}, {Start: 40, End: 80}, {Start: 10000, End: 20000}}
}
func TestLegacyUnknown(t *testing.T) {
@ -422,8 +425,8 @@ func TestLegacyUnknown(t *testing.T) {
return out
}
m := new(legacyUnknownMessage)
fs := new(MessageType).MessageOf(m).UnknownFields()
m := new(legacyTestMessage)
fs := MessageOf(m).UnknownFields()
if got, want := fs.Len(), 0; got != want {
t.Errorf("Len() = %d, want %d", got, want)
@ -618,3 +621,284 @@ func TestLegacyUnknown(t *testing.T) {
return i < 2
})
}
func TestLegactExtensions(t *testing.T) {
extensions := []pref.ExtensionType{
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: (*bool)(nil),
Field: 10000,
Name: "fizz.buzz.optional_bool",
Tag: "varint,10000,opt,name=optional_bool,json=optionalBool,def=1",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: (*int32)(nil),
Field: 10001,
Name: "fizz.buzz.optional_int32",
Tag: "varint,10001,opt,name=optional_int32,json=optionalInt32,def=-12345",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: (*uint32)(nil),
Field: 10002,
Name: "fizz.buzz.optional_uint32",
Tag: "varint,10002,opt,name=optional_uint32,json=optionalUint32,def=3200",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: (*float32)(nil),
Field: 10003,
Name: "fizz.buzz.optional_float",
Tag: "fixed32,10003,opt,name=optional_float,json=optionalFloat,def=3.14159",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: (*string)(nil),
Field: 10004,
Name: "fizz.buzz.optional_string",
Tag: "bytes,10004,opt,name=optional_string,json=optionalString,def=hello, \"world!\"\n",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([]byte)(nil),
Field: 10005,
Name: "fizz.buzz.optional_bytes",
Tag: "bytes,10005,opt,name=optional_bytes,json=optionalBytes,def=dead\\336\\255\\276\\357beef",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: (*proto2_20180125.Message_ChildEnum)(nil),
Field: 10006,
Name: "fizz.buzz.optional_enum",
Tag: "varint,10006,opt,name=optional_enum,json=optionalEnum,enum=google.golang.org.proto2_20180125.Message_ChildEnum,def=0",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: (*proto2_20180125.Message_ChildMessage)(nil),
Field: 10007,
Name: "fizz.buzz.optional_message",
Tag: "bytes,10007,opt,name=optional_message,json=optionalMessage",
Filename: "fizz/buzz/test.proto",
}),
// TODO: Test v2 enum and messages.
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([]bool)(nil),
Field: 10008,
Name: "fizz.buzz.repeated_bool",
Tag: "varint,10008,rep,name=repeated_bool,json=repeatedBool",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([]int32)(nil),
Field: 10009,
Name: "fizz.buzz.repeated_int32",
Tag: "varint,10009,rep,name=repeated_int32,json=repeatedInt32",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([]uint32)(nil),
Field: 10010,
Name: "fizz.buzz.repeated_uint32",
Tag: "varint,10010,rep,name=repeated_uint32,json=repeatedUint32",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([]float32)(nil),
Field: 10011,
Name: "fizz.buzz.repeated_float",
Tag: "fixed32,10011,rep,name=repeated_float,json=repeatedFloat",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([]string)(nil),
Field: 10012,
Name: "fizz.buzz.repeated_string",
Tag: "bytes,10012,rep,name=repeated_string,json=repeatedString",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([][]byte)(nil),
Field: 10013,
Name: "fizz.buzz.repeated_bytes",
Tag: "bytes,10013,rep,name=repeated_bytes,json=repeatedBytes",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([]proto2_20180125.Message_ChildEnum)(nil),
Field: 10014,
Name: "fizz.buzz.repeated_enum",
Tag: "varint,10014,rep,name=repeated_enum,json=repeatedEnum,enum=google.golang.org.proto2_20180125.Message_ChildEnum",
Filename: "fizz/buzz/test.proto",
}),
legacyExtensionTypeOf(&protoV1.ExtensionDesc{
ExtendedType: (*legacyTestMessage)(nil),
ExtensionType: ([]*proto2_20180125.Message_ChildMessage)(nil),
Field: 10015,
Name: "fizz.buzz.repeated_message",
Tag: "bytes,10015,rep,name=repeated_message,json=repeatedMessage",
Filename: "fizz/buzz/test.proto",
}),
// TODO: Test v2 enum and messages.
}
opts := cmp.Options{cmp.Comparer(func(x, y *proto2_20180125.Message_ChildMessage) bool {
return x == y // pointer compare messages for object identity
})}
m := new(legacyTestMessage)
fs := MessageOf(m).KnownFields()
ts := fs.ExtensionTypes()
if n := fs.Len(); n != 0 {
t.Errorf("KnownFields.Len() = %v, want 0", n)
}
if n := ts.Len(); n != 0 {
t.Errorf("ExtensionFieldTypes.Len() = %v, want 0", n)
}
// Register all the extension types.
for _, xt := range extensions {
ts.Register(xt)
}
// Check that getting the zero value returns the default value for scalars,
// nil for singular messages, and an empty vector for repeated fields.
defaultValues := []interface{}{
bool(true),
int32(-12345),
uint32(3200),
float32(3.14159),
string("hello, \"world!\"\n"),
[]byte("dead\xde\xad\xbe\xefbeef"),
proto2_20180125.Message_ALPHA,
nil,
new([]bool),
new([]int32),
new([]uint32),
new([]float32),
new([]string),
new([][]byte),
new([]proto2_20180125.Message_ChildEnum),
new([]*proto2_20180125.Message_ChildMessage),
}
for i, xt := range extensions {
var got interface{}
v := fs.Get(xt.Number())
if xt.Cardinality() != pref.Repeated && xt.Kind() == pref.MessageKind {
got = v.Interface()
} else {
got = xt.InterfaceOf(v) // TODO: Simplify this if InterfaceOf allows nil
}
want := defaultValues[i]
if diff := cmp.Diff(want, got, opts); diff != "" {
t.Errorf("KnownFields.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
}
}
// All fields should be unpopulated.
for _, xt := range extensions {
if fs.Has(xt.Number()) {
t.Errorf("KnownFields.Has(%d) = true, want false", xt.Number())
}
}
// Set some values and append to values to the vectors.
m1 := &proto2_20180125.Message_ChildMessage{F1: protoV1.String("m1")}
m2 := &proto2_20180125.Message_ChildMessage{F1: protoV1.String("m2")}
setValues := []interface{}{
bool(false),
int32(-54321),
uint32(6400),
float32(2.71828),
string("goodbye, \"world!\"\n"),
[]byte("live\xde\xad\xbe\xefchicken"),
proto2_20180125.Message_CHARLIE,
m1,
&[]bool{true},
&[]int32{-1000},
&[]uint32{1280},
&[]float32{1.6180},
&[]string{"zero"},
&[][]byte{[]byte("zero")},
&[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO},
&[]*proto2_20180125.Message_ChildMessage{m2},
}
for i, xt := range extensions {
fs.Set(xt.Number(), xt.ValueOf(setValues[i]))
}
for i, xt := range extensions[len(extensions)/2:] {
v := extensions[i].ValueOf(setValues[i])
fs.Get(xt.Number()).Vector().Append(v)
}
// Get the values and check for equality.
getValues := []interface{}{
bool(false),
int32(-54321),
uint32(6400),
float32(2.71828),
string("goodbye, \"world!\"\n"),
[]byte("live\xde\xad\xbe\xefchicken"),
proto2_20180125.Message_ChildEnum(proto2_20180125.Message_CHARLIE),
m1,
&[]bool{true, false},
&[]int32{-1000, -54321},
&[]uint32{1280, 6400},
&[]float32{1.6180, 2.71828},
&[]string{"zero", "goodbye, \"world!\"\n"},
&[][]byte{[]byte("zero"), []byte("live\xde\xad\xbe\xefchicken")},
&[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO, proto2_20180125.Message_CHARLIE},
&[]*proto2_20180125.Message_ChildMessage{m2, m1},
}
for i, xt := range extensions {
got := xt.InterfaceOf(fs.Get(xt.Number()))
want := getValues[i]
if diff := cmp.Diff(want, got, opts); diff != "" {
t.Errorf("KnownFields.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
}
}
if n := fs.Len(); n != 16 {
t.Errorf("KnownFields.Len() = %v, want 0", n)
}
if n := ts.Len(); n != 16 {
t.Errorf("ExtensionFieldTypes.Len() = %v, want 16", n)
}
// Clear the field for all extension types.
for _, xt := range extensions {
fs.Clear(xt.Number())
}
if n := fs.Len(); n != 0 {
t.Errorf("KnownFields.Len() = %v, want 0", n)
}
if n := ts.Len(); n != 16 {
t.Errorf("ExtensionFieldTypes.Len() = %v, want 16", n)
}
// De-register all extension types.
for _, xt := range extensions {
ts.Remove(xt)
}
if n := fs.Len(); n != 0 {
t.Errorf("KnownFields.Len() = %v, want 0", n)
}
if n := ts.Len(); n != 0 {
t.Errorf("ExtensionFieldTypes.Len() = %v, want 0", n)
}
}

View File

@ -29,7 +29,7 @@ func makeLegacyUnknownFieldsFunc(t reflect.Type) func(p *messageDataType) pref.U
if extFunc != nil {
return func(p *messageDataType) pref.UnknownFields {
return &legacyUnknownBytesAndExtensionMap{
unkFunc(p), extFunc(p), p.mi.Desc.ExtensionRanges(),
unkFunc(p), extFunc(p), p.mi.Type.ExtensionRanges(),
}
}
}

View File

@ -12,24 +12,30 @@ import (
"sync"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
ptype "github.com/golang/protobuf/v2/reflect/prototype"
)
// MessageOf returns the protoreflect.Message interface over p.
// If p already implements proto.Message, then it directly calls the
// ProtoReflect method, otherwise it wraps the legacy v1 message to implement
// the v2 reflective interface.
func MessageOf(p interface{}) pref.Message {
if m, ok := p.(pref.ProtoMessage); ok {
return m.ProtoReflect()
}
return legacyWrapMessage(reflect.ValueOf(p)).ProtoReflect()
}
// MessageType provides protobuf related functionality for a given Go type
// that represents a message. A given instance of MessageType is tied to
// exactly one Go type, which must be a pointer to a struct type.
type MessageType struct {
// Desc is an optionally provided message descriptor. If nil, the descriptor
// is lazily derived from the Go type information of generated messages
// for the v1 API.
//
// Type is the underlying message type and must be populated.
// Once set, this field must never be mutated.
Desc pref.MessageDescriptor
Type pref.MessageType
once sync.Once // protects all unexported fields
goType reflect.Type // pointer to struct
pbType pref.MessageType // only valid if goType does not implement proto.Message
goType reflect.Type // pointer to struct
// TODO: Split fields into dense and sparse maps similar to the current
// table-driven implementation in v1?
@ -45,33 +51,12 @@ type MessageType struct {
// It must be called at the start of every exported method.
func (mi *MessageType) init(p interface{}) {
mi.once.Do(func() {
v := reflect.ValueOf(p)
t := v.Type()
t := reflect.TypeOf(p)
if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct {
panic(fmt.Sprintf("got %v, want *struct kind", t))
}
mi.goType = t
// Derive the message descriptor if unspecified.
if mi.Desc == nil {
mi.Desc = loadMessageDesc(t)
}
// Initialize the Go message type wrapper if the Go type does not
// implement the proto.Message interface.
//
// Otherwise, we assume that the Go type manually implements the
// interface and is internally consistent such that:
// goType == reflect.New(goType.Elem()).Interface().(proto.Message).ProtoReflect().Type().GoType()
//
// Generated code ensures that this property holds.
if _, ok := p.(pref.ProtoMessage); !ok {
mi.pbType = ptype.GoMessage(mi.Desc, func(pref.MessageType) pref.ProtoMessage {
p := reflect.New(t.Elem()).Interface()
return (*legacyMessageWrapper)(mi.dataTypeOf(p))
})
}
mi.makeKnownFieldsFunc(t.Elem())
mi.makeUnknownFieldsFunc(t.Elem())
mi.makeExtensionFieldsFunc(t.Elem())
@ -86,10 +71,9 @@ func (mi *MessageType) init(p interface{}) {
}
}
// makeKnownFieldsFunc generates per-field functions for all operations
// to be performed on each field. It takes in a reflect.Type representing the
// Go struct, and a protoreflect.MessageDescriptor to match with the fields
// in the struct.
// makeKnownFieldsFunc generates functions for operations that can be performed
// on each protobuf message field. It takes in a reflect.Type representing the
// Go struct and matches message fields with struct fields.
//
// This code assumes that the struct is well-formed and panics if there are
// any discrepancies.
@ -136,8 +120,8 @@ fieldLoop:
}
mi.fields = map[pref.FieldNumber]*fieldInfo{}
for i := 0; i < mi.Desc.Fields().Len(); i++ {
fd := mi.Desc.Fields().Get(i)
for i := 0; i < mi.Type.Fields().Len(); i++ {
fd := mi.Type.Fields().Get(i)
fs := fields[fd.Number()]
var fi fieldInfo
switch {
@ -169,33 +153,25 @@ func (mi *MessageType) makeUnknownFieldsFunc(t reflect.Type) {
}
func (mi *MessageType) makeExtensionFieldsFunc(t reflect.Type) {
// TODO
if f := makeLegacyExtensionFieldsFunc(t); f != nil {
mi.extensionFields = f
return
}
mi.extensionFields = func(*messageDataType) pref.KnownFields {
return emptyExtensionFields{}
}
}
func (mi *MessageType) MessageOf(p interface{}) pref.Message {
mi.init(p)
if m, ok := p.(pref.ProtoMessage); ok {
// We assume p properly implements protoreflect.Message.
// See the comment in MessageType.init regarding pbType.
return m.ProtoReflect()
}
return (*legacyMessageWrapper)(mi.dataTypeOf(p))
}
func (mi *MessageType) KnownFieldsOf(p interface{}) pref.KnownFields {
mi.init(p)
return (*knownFields)(mi.dataTypeOf(p))
}
func (mi *MessageType) UnknownFieldsOf(p interface{}) pref.UnknownFields {
mi.init(p)
return mi.unknownFields(mi.dataTypeOf(p))
}
func (mi *MessageType) dataTypeOf(p interface{}) *messageDataType {
mi.init(p)
return &messageDataType{pointerOfIface(&p), mi}
}
@ -249,20 +225,30 @@ func (fs *knownFields) Set(n pref.FieldNumber, v pref.Value) {
fi.set(fs.p, v)
return
}
fs.extensionFields().Set(n, v)
if fs.mi.Type.ExtensionRanges().Has(n) {
fs.extensionFields().Set(n, v)
return
}
panic(fmt.Sprintf("invalid field: %d", n))
}
func (fs *knownFields) Clear(n pref.FieldNumber) {
if fi := fs.mi.fields[n]; fi != nil {
fi.clear(fs.p)
return
}
fs.extensionFields().Clear(n)
if fs.mi.Type.ExtensionRanges().Has(n) {
fs.extensionFields().Clear(n)
return
}
}
func (fs *knownFields) Mutable(n pref.FieldNumber) pref.Mutable {
if fi := fs.mi.fields[n]; fi != nil {
return fi.mutable(fs.p)
}
return fs.extensionFields().Mutable(n)
if fs.mi.Type.ExtensionRanges().Has(n) {
return fs.extensionFields().Mutable(n)
}
panic(fmt.Sprintf("invalid field: %d", n))
}
func (fs *knownFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
for n, fi := range fs.mi.fields {
@ -291,14 +277,14 @@ func (emptyUnknownFields) IsSupported() bool { r
type emptyExtensionFields struct{}
func (emptyExtensionFields) Len() int { return 0 }
func (emptyExtensionFields) Has(pref.FieldNumber) bool { return false }
func (emptyExtensionFields) Get(pref.FieldNumber) pref.Value { return pref.Value{} }
func (emptyExtensionFields) Set(pref.FieldNumber, pref.Value) { panic("extensions not supported") }
func (emptyExtensionFields) Clear(pref.FieldNumber) { return } // noop
func (emptyExtensionFields) Mutable(pref.FieldNumber) pref.Mutable { panic("extensions not supported") }
func (emptyExtensionFields) Range(f func(pref.FieldNumber, pref.Value) bool) { return }
func (emptyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes { return emptyExtensionTypes{} }
func (emptyExtensionFields) Len() int { return 0 }
func (emptyExtensionFields) Has(pref.FieldNumber) bool { return false }
func (emptyExtensionFields) Get(pref.FieldNumber) pref.Value { return pref.Value{} }
func (emptyExtensionFields) Set(pref.FieldNumber, pref.Value) { panic("extensions not supported") }
func (emptyExtensionFields) Clear(pref.FieldNumber) { return } // noop
func (emptyExtensionFields) Mutable(pref.FieldNumber) pref.Mutable { panic("extensions not supported") }
func (emptyExtensionFields) Range(func(pref.FieldNumber, pref.Value) bool) { return }
func (emptyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes { return emptyExtensionTypes{} }
type emptyExtensionTypes struct{}

View File

@ -9,7 +9,7 @@ import (
"reflect"
"github.com/golang/protobuf/v2/internal/flags"
"github.com/golang/protobuf/v2/internal/value"
pvalue "github.com/golang/protobuf/v2/internal/value"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
)
@ -42,7 +42,7 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, ot refle
if !reflect.PtrTo(ot).Implements(ft) {
panic(fmt.Sprintf("invalid type: %v does not implement %v", ot, ft))
}
conv := value.NewLegacyConverter(ot.Field(0).Type, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage)
conv := newConverter(ot.Field(0).Type, fd.Kind())
fieldOffset := offsetOf(fs)
// TODO: Implement unsafe fast path?
return fieldInfo{
@ -106,8 +106,8 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo
if ft.Kind() != reflect.Map {
panic(fmt.Sprintf("invalid type: got %v, want map kind", ft))
}
keyConv := value.NewLegacyConverter(ft.Key(), fd.MessageType().Fields().ByNumber(1).Kind(), wrapLegacyEnum, wrapLegacyMessage)
valConv := value.NewLegacyConverter(ft.Elem(), fd.MessageType().Fields().ByNumber(2).Kind(), wrapLegacyEnum, wrapLegacyMessage)
keyConv := newConverter(ft.Key(), fd.MessageType().Fields().ByNumber(1).Kind())
valConv := newConverter(ft.Elem(), fd.MessageType().Fields().ByNumber(2).Kind())
fieldOffset := offsetOf(fs)
// TODO: Implement unsafe fast path?
return fieldInfo{
@ -117,11 +117,11 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo
},
get: func(p pointer) pref.Value {
v := p.apply(fieldOffset).asType(fs.Type).Interface()
return pref.ValueOf(value.MapOf(v, keyConv, valConv))
return pref.ValueOf(pvalue.MapOf(v, keyConv, valConv))
},
set: func(p pointer, v pref.Value) {
rv := p.apply(fieldOffset).asType(fs.Type).Elem()
rv.Set(reflect.ValueOf(v.Map().(value.Unwrapper).Unwrap()))
rv.Set(reflect.ValueOf(v.Map().(pvalue.Unwrapper).Unwrap()).Elem())
},
clear: func(p pointer) {
rv := p.apply(fieldOffset).asType(fs.Type).Elem()
@ -129,7 +129,7 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo
},
mutable: func(p pointer) pref.Mutable {
v := p.apply(fieldOffset).asType(fs.Type).Interface()
return value.MapOf(v, keyConv, valConv)
return pvalue.MapOf(v, keyConv, valConv)
},
}
}
@ -139,7 +139,7 @@ func fieldInfoForVector(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
if ft.Kind() != reflect.Slice {
panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
}
conv := value.NewLegacyConverter(ft.Elem(), fd.Kind(), wrapLegacyEnum, wrapLegacyMessage)
conv := newConverter(ft.Elem(), fd.Kind())
fieldOffset := offsetOf(fs)
// TODO: Implement unsafe fast path?
return fieldInfo{
@ -149,11 +149,11 @@ func fieldInfoForVector(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
},
get: func(p pointer) pref.Value {
v := p.apply(fieldOffset).asType(fs.Type).Interface()
return pref.ValueOf(value.VectorOf(v, conv))
return pref.ValueOf(pvalue.VectorOf(v, conv))
},
set: func(p pointer, v pref.Value) {
rv := p.apply(fieldOffset).asType(fs.Type).Elem()
rv.Set(reflect.ValueOf(v.Vector().(value.Unwrapper).Unwrap()))
rv.Set(reflect.ValueOf(v.Vector().(pvalue.Unwrapper).Unwrap()).Elem())
},
clear: func(p pointer) {
rv := p.apply(fieldOffset).asType(fs.Type).Elem()
@ -161,7 +161,7 @@ func fieldInfoForVector(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
},
mutable: func(p pointer) pref.Mutable {
v := p.apply(fieldOffset).asType(fs.Type).Interface()
return value.VectorOf(v, conv)
return pvalue.VectorOf(v, conv)
},
}
}
@ -179,7 +179,7 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
ft = ft.Elem()
}
}
conv := value.NewLegacyConverter(ft, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage)
conv := newConverter(ft, fd.Kind())
fieldOffset := offsetOf(fs)
// TODO: Implement unsafe fast path?
return fieldInfo{
@ -244,7 +244,7 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn
func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo {
ft := fs.Type
conv := value.NewLegacyConverter(ft, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage)
conv := newConverter(ft, fd.Kind())
fieldOffset := offsetOf(fs)
// TODO: Implement unsafe fast path?
return fieldInfo{
@ -282,3 +282,12 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldI
},
}
}
// newConverter calls value.NewLegacyConverter with the necessary constructor
// functions for legacy enum and message support.
func newConverter(t reflect.Type, k pref.Kind) pvalue.Converter {
messageType := func(t reflect.Type) pref.MessageType {
return legacyLoadMessageType(t).Type
}
return pvalue.NewLegacyConverter(t, k, legacyLoadEnumType, messageType, legacyWrapMessage)
}

View File

@ -13,7 +13,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/golang/protobuf/proto"
protoV1 "github.com/golang/protobuf/proto"
descriptorV1 "github.com/golang/protobuf/protoc-gen-go/descriptor"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
@ -46,29 +45,8 @@ type (
MapStrings map[MyString]MyString
MapBytes map[MyString]MyBytes
MyEnumV1 pref.EnumNumber
MyEnumV2 string
myEnumV2 MyEnumV2
MyMessageV1 struct {
// SubMessage *Message
}
MyMessageV2 map[pref.FieldNumber]pref.Value
myMessageV2 MyMessageV2
)
func (e MyEnumV2) ProtoReflect() pref.Enum { return myEnumV2(e) }
func (e myEnumV2) Type() pref.EnumType { return nil } // TODO
func (e myEnumV2) Number() pref.EnumNumber { return 0 } // TODO
func (m *MyMessageV2) ProtoReflect() pref.Message { return (*myMessageV2)(m) }
func (m *myMessageV2) Type() pref.MessageType { return nil } // TODO
func (m *myMessageV2) KnownFields() pref.KnownFields { return nil } // TODO
func (m *myMessageV2) UnknownFields() pref.UnknownFields { return nil } // TODO
func (m *myMessageV2) Interface() pref.ProtoMessage { return (*MyMessageV2)(m) }
func (m *myMessageV2) ProtoMutable() {}
// List of test operations to perform on messages, vectors, or maps.
type (
messageOp interface{} // equalMessage | hasFields | getFields | setFields | clearFields | vectorFields | mapFields
@ -143,8 +121,8 @@ type ScalarProto2 struct {
MyBytesA *MyString `protobuf:"22"`
}
func TestScalarProto2(t *testing.T) {
mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
var scalarProto2Type = MessageType{Type: ptype.GoMessage(
mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: "ScalarProto2",
Fields: []ptype.Field{
@ -172,9 +150,21 @@ func TestScalarProto2(t *testing.T) {
{Name: "f21", Number: 21, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("21"))},
{Name: "f22", Number: 22, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("22"))},
},
})}
}),
func(pref.MessageType) pref.ProtoMessage {
return new(ScalarProto2)
},
)}
testMessage(t, nil, mi.MessageOf(&ScalarProto2{}), messageOps{
func (m *ScalarProto2) Type() pref.MessageType { return scalarProto2Type.Type }
func (m *ScalarProto2) KnownFields() pref.KnownFields { return scalarProto2Type.KnownFieldsOf(m) }
func (m *ScalarProto2) UnknownFields() pref.UnknownFields { return scalarProto2Type.UnknownFieldsOf(m) }
func (m *ScalarProto2) Interface() pref.ProtoMessage { return m }
func (m *ScalarProto2) ProtoReflect() pref.Message { return m }
func (m *ScalarProto2) ProtoMutable() {}
func TestScalarProto2(t *testing.T) {
testMessage(t, nil, &ScalarProto2{}, messageOps{
hasFields{
1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false,
12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false,
@ -191,15 +181,15 @@ func TestScalarProto2(t *testing.T) {
1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true,
12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true,
},
equalMessage(mi.MessageOf(&ScalarProto2{
equalMessage(&ScalarProto2{
new(bool), new(int32), new(int64), new(uint32), new(uint64), new(float32), new(float64), new(string), []byte{}, []byte{}, new(string),
new(MyBool), new(MyInt32), new(MyInt64), new(MyUint32), new(MyUint64), new(MyFloat32), new(MyFloat64), new(MyString), MyBytes{}, MyBytes{}, new(MyString),
})),
}),
clearFields{
1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true,
12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true,
},
equalMessage(mi.MessageOf(&ScalarProto2{})),
equalMessage(&ScalarProto2{}),
})
}
@ -229,8 +219,8 @@ type ScalarProto3 struct {
MyBytesA MyString `protobuf:"22"`
}
func TestScalarProto3(t *testing.T) {
mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
var scalarProto3Type = MessageType{Type: ptype.GoMessage(
mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto3,
FullName: "ScalarProto3",
Fields: []ptype.Field{
@ -258,9 +248,21 @@ func TestScalarProto3(t *testing.T) {
{Name: "f21", Number: 21, Cardinality: pref.Optional, Kind: pref.BytesKind},
{Name: "f22", Number: 22, Cardinality: pref.Optional, Kind: pref.BytesKind},
},
})}
}),
func(pref.MessageType) pref.ProtoMessage {
return new(ScalarProto3)
},
)}
testMessage(t, nil, mi.MessageOf(&ScalarProto3{}), messageOps{
func (m *ScalarProto3) Type() pref.MessageType { return scalarProto3Type.Type }
func (m *ScalarProto3) KnownFields() pref.KnownFields { return scalarProto3Type.KnownFieldsOf(m) }
func (m *ScalarProto3) UnknownFields() pref.UnknownFields { return scalarProto3Type.UnknownFieldsOf(m) }
func (m *ScalarProto3) Interface() pref.ProtoMessage { return m }
func (m *ScalarProto3) ProtoReflect() pref.Message { return m }
func (m *ScalarProto3) ProtoMutable() {}
func TestScalarProto3(t *testing.T) {
testMessage(t, nil, &ScalarProto3{}, messageOps{
hasFields{
1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false,
12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false,
@ -277,7 +279,7 @@ func TestScalarProto3(t *testing.T) {
1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false,
12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false,
},
equalMessage(mi.MessageOf(&ScalarProto3{})),
equalMessage(&ScalarProto3{}),
setFields{
1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V([]byte("10")), 11: V([]byte("11")),
12: V(bool(true)), 13: V(int32(13)), 14: V(int64(14)), 15: V(uint32(15)), 16: V(uint64(16)), 17: V(float32(17)), 18: V(float64(18)), 19: V(string("19")), 20: V(string("20")), 21: V([]byte("21")), 22: V([]byte("22")),
@ -286,19 +288,19 @@ func TestScalarProto3(t *testing.T) {
1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true,
12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true,
},
equalMessage(mi.MessageOf(&ScalarProto3{
equalMessage(&ScalarProto3{
true, 2, 3, 4, 5, 6, 7, "8", []byte("9"), []byte("10"), "11",
true, 13, 14, 15, 16, 17, 18, "19", []byte("20"), []byte("21"), "22",
})),
}),
clearFields{
1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true,
12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true,
},
equalMessage(mi.MessageOf(&ScalarProto3{})),
equalMessage(&ScalarProto3{}),
})
}
type RepeatedScalars struct {
type ListScalars struct {
Bools []bool `protobuf:"1"`
Int32s []int32 `protobuf:"2"`
Int64s []int64 `protobuf:"3"`
@ -322,10 +324,10 @@ type RepeatedScalars struct {
MyBytes4 VectorStrings `protobuf:"19"`
}
func TestRepeatedScalars(t *testing.T) {
mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
var listScalarsType = MessageType{Type: ptype.GoMessage(
mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: "RepeatedScalars",
FullName: "ListScalars",
Fields: []ptype.Field{
{Name: "f1", Number: 1, Cardinality: pref.Repeated, Kind: pref.BoolKind},
{Name: "f2", Number: 2, Cardinality: pref.Repeated, Kind: pref.Int32Kind},
@ -349,12 +351,24 @@ func TestRepeatedScalars(t *testing.T) {
{Name: "f18", Number: 18, Cardinality: pref.Repeated, Kind: pref.BytesKind},
{Name: "f19", Number: 19, Cardinality: pref.Repeated, Kind: pref.BytesKind},
},
})}
}),
func(pref.MessageType) pref.ProtoMessage {
return new(ListScalars)
},
)}
empty := mi.MessageOf(&RepeatedScalars{})
func (m *ListScalars) Type() pref.MessageType { return listScalarsType.Type }
func (m *ListScalars) KnownFields() pref.KnownFields { return listScalarsType.KnownFieldsOf(m) }
func (m *ListScalars) UnknownFields() pref.UnknownFields { return listScalarsType.UnknownFieldsOf(m) }
func (m *ListScalars) Interface() pref.ProtoMessage { return m }
func (m *ListScalars) ProtoReflect() pref.Message { return m }
func (m *ListScalars) ProtoMutable() {}
func TestListScalars(t *testing.T) {
empty := &ListScalars{}
emptyFS := empty.KnownFields()
want := mi.MessageOf(&RepeatedScalars{
want := &ListScalars{
Bools: []bool{true, false, true},
Int32s: []int32{2, math.MinInt32, math.MaxInt32},
Int64s: []int64{3, math.MinInt64, math.MaxInt64},
@ -376,10 +390,10 @@ func TestRepeatedScalars(t *testing.T) {
MyStrings4: VectorBytes{[]byte("17"), nil, []byte("seventeen")},
MyBytes3: VectorBytes{[]byte("18"), nil, []byte("eighteen")},
MyBytes4: VectorStrings{"19", "", "nineteen"},
})
}
wantFS := want.KnownFields()
testMessage(t, nil, mi.MessageOf(&RepeatedScalars{}), messageOps{
testMessage(t, nil, &ListScalars{}, messageOps{
hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false},
getFields{1: emptyFS.Get(1), 3: emptyFS.Get(3), 5: emptyFS.Get(5), 7: emptyFS.Get(7), 9: emptyFS.Get(9), 11: emptyFS.Get(11), 13: emptyFS.Get(13), 15: emptyFS.Get(15), 17: emptyFS.Get(17), 19: emptyFS.Get(19)},
setFields{1: wantFS.Get(1), 3: wantFS.Get(3), 5: wantFS.Get(5), 7: wantFS.Get(7), 9: wantFS.Get(9), 11: wantFS.Get(11), 13: wantFS.Get(13), 15: wantFS.Get(15), 17: wantFS.Get(17), 19: wantFS.Get(19)},
@ -470,25 +484,26 @@ type MapScalars struct {
MyBytes4 MapStrings `protobuf:"25"`
}
func TestMapScalars(t *testing.T) {
mustMakeMapEntry := func(n pref.FieldNumber, keyKind, valKind pref.Kind) ptype.Field {
return ptype.Field{
Name: pref.Name(fmt.Sprintf("f%d", n)),
Number: n,
Cardinality: pref.Repeated,
Kind: pref.MessageKind,
MessageType: mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: pref.FullName(fmt.Sprintf("MapScalars.F%dEntry", n)),
Fields: []ptype.Field{
{Name: "key", Number: 1, Cardinality: pref.Optional, Kind: keyKind},
{Name: "value", Number: 2, Cardinality: pref.Optional, Kind: valKind},
},
Options: &descriptorV1.MessageOptions{MapEntry: protoV1.Bool(true)},
}),
}
func mustMakeMapEntry(n pref.FieldNumber, keyKind, valKind pref.Kind) ptype.Field {
return ptype.Field{
Name: pref.Name(fmt.Sprintf("f%d", n)),
Number: n,
Cardinality: pref.Repeated,
Kind: pref.MessageKind,
MessageType: mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: pref.FullName(fmt.Sprintf("MapScalars.F%dEntry", n)),
Fields: []ptype.Field{
{Name: "key", Number: 1, Cardinality: pref.Optional, Kind: keyKind},
{Name: "value", Number: 2, Cardinality: pref.Optional, Kind: valKind},
},
Options: &descriptorV1.MessageOptions{MapEntry: protoV1.Bool(true)},
}),
}
mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
}
var mapScalarsType = MessageType{Type: ptype.GoMessage(
mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: "MapScalars",
Fields: []ptype.Field{
@ -521,12 +536,24 @@ func TestMapScalars(t *testing.T) {
mustMakeMapEntry(24, pref.StringKind, pref.BytesKind),
mustMakeMapEntry(25, pref.StringKind, pref.BytesKind),
},
})}
}),
func(pref.MessageType) pref.ProtoMessage {
return new(MapScalars)
},
)}
empty := mi.MessageOf(&MapScalars{})
func (m *MapScalars) Type() pref.MessageType { return mapScalarsType.Type }
func (m *MapScalars) KnownFields() pref.KnownFields { return mapScalarsType.KnownFieldsOf(m) }
func (m *MapScalars) UnknownFields() pref.UnknownFields { return mapScalarsType.UnknownFieldsOf(m) }
func (m *MapScalars) Interface() pref.ProtoMessage { return m }
func (m *MapScalars) ProtoReflect() pref.Message { return m }
func (m *MapScalars) ProtoMutable() {}
func TestMapScalars(t *testing.T) {
empty := &MapScalars{}
emptyFS := empty.KnownFields()
want := mi.MessageOf(&MapScalars{
want := &MapScalars{
KeyBools: map[bool]string{true: "true", false: "false"},
KeyInt32s: map[int32]string{0: "zero", -1: "one", 2: "two"},
KeyInt64s: map[int64]string{0: "zero", -10: "ten", 20: "twenty"},
@ -555,10 +582,10 @@ func TestMapScalars(t *testing.T) {
MyStrings4: MapBytes{"s1": []byte("s1"), "s2": []byte("s2")},
MyBytes3: MapBytes{"s1": []byte("s1"), "s2": []byte("s2")},
MyBytes4: MapStrings{"s1": "s1", "s2": "s2"},
})
}
wantFS := want.KnownFields()
testMessage(t, nil, mi.MessageOf(&MapScalars{}), messageOps{
testMessage(t, nil, &MapScalars{}, messageOps{
hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, 23: false, 24: false, 25: false},
getFields{1: emptyFS.Get(1), 3: emptyFS.Get(3), 5: emptyFS.Get(5), 7: emptyFS.Get(7), 9: emptyFS.Get(9), 11: emptyFS.Get(11), 13: emptyFS.Get(13), 15: emptyFS.Get(15), 17: emptyFS.Get(17), 19: emptyFS.Get(19), 21: emptyFS.Get(21), 23: emptyFS.Get(23), 25: emptyFS.Get(25)},
setFields{1: wantFS.Get(1), 3: wantFS.Get(3), 5: wantFS.Get(5), 7: wantFS.Get(7), 9: wantFS.Get(9), 11: wantFS.Get(11), 13: wantFS.Get(13), 15: wantFS.Get(15), 17: wantFS.Get(17), 19: wantFS.Get(19), 21: wantFS.Get(21), 23: wantFS.Get(23), 25: wantFS.Get(25)},
@ -687,7 +714,40 @@ type (
}
)
func (*OneofScalars) XXX_OneofFuncs() (func(proto.Message, *proto.Buffer) error, func(proto.Message, int, int, *proto.Buffer) (bool, error), func(proto.Message) int, []interface{}) {
var oneofScalarsType = MessageType{Type: ptype.GoMessage(
mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: "ScalarProto2",
Fields: []ptype.Field{
{Name: "f1", Number: 1, Cardinality: pref.Optional, Kind: pref.BoolKind, Default: V(bool(true)), OneofName: "union"},
{Name: "f2", Number: 2, Cardinality: pref.Optional, Kind: pref.Int32Kind, Default: V(int32(2)), OneofName: "union"},
{Name: "f3", Number: 3, Cardinality: pref.Optional, Kind: pref.Int64Kind, Default: V(int64(3)), OneofName: "union"},
{Name: "f4", Number: 4, Cardinality: pref.Optional, Kind: pref.Uint32Kind, Default: V(uint32(4)), OneofName: "union"},
{Name: "f5", Number: 5, Cardinality: pref.Optional, Kind: pref.Uint64Kind, Default: V(uint64(5)), OneofName: "union"},
{Name: "f6", Number: 6, Cardinality: pref.Optional, Kind: pref.FloatKind, Default: V(float32(6)), OneofName: "union"},
{Name: "f7", Number: 7, Cardinality: pref.Optional, Kind: pref.DoubleKind, Default: V(float64(7)), OneofName: "union"},
{Name: "f8", Number: 8, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("8")), OneofName: "union"},
{Name: "f9", Number: 9, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("9")), OneofName: "union"},
{Name: "f10", Number: 10, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("10")), OneofName: "union"},
{Name: "f11", Number: 11, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("11")), OneofName: "union"},
{Name: "f12", Number: 12, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("12")), OneofName: "union"},
{Name: "f13", Number: 13, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("13")), OneofName: "union"},
},
Oneofs: []ptype.Oneof{{Name: "union"}},
}),
func(pref.MessageType) pref.ProtoMessage {
return new(OneofScalars)
},
)}
func (m *OneofScalars) Type() pref.MessageType { return oneofScalarsType.Type }
func (m *OneofScalars) KnownFields() pref.KnownFields { return oneofScalarsType.KnownFieldsOf(m) }
func (m *OneofScalars) UnknownFields() pref.UnknownFields { return oneofScalarsType.UnknownFieldsOf(m) }
func (m *OneofScalars) Interface() pref.ProtoMessage { return m }
func (m *OneofScalars) ProtoReflect() pref.Message { return m }
func (m *OneofScalars) ProtoMutable() {}
func (*OneofScalars) XXX_OneofFuncs() (func(protoV1.Message, *protoV1.Buffer) error, func(protoV1.Message, int, int, *protoV1.Buffer) (bool, error), func(protoV1.Message) int, []interface{}) {
return nil, nil, nil, []interface{}{
(*OneofScalars_Bool)(nil),
(*OneofScalars_Int32)(nil),
@ -720,43 +780,22 @@ func (*OneofScalars_BytesA) isOneofScalars_Union() {}
func (*OneofScalars_BytesB) isOneofScalars_Union() {}
func TestOneofs(t *testing.T) {
mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: "ScalarProto2",
Fields: []ptype.Field{
{Name: "f1", Number: 1, Cardinality: pref.Optional, Kind: pref.BoolKind, Default: V(bool(true)), OneofName: "union"},
{Name: "f2", Number: 2, Cardinality: pref.Optional, Kind: pref.Int32Kind, Default: V(int32(2)), OneofName: "union"},
{Name: "f3", Number: 3, Cardinality: pref.Optional, Kind: pref.Int64Kind, Default: V(int64(3)), OneofName: "union"},
{Name: "f4", Number: 4, Cardinality: pref.Optional, Kind: pref.Uint32Kind, Default: V(uint32(4)), OneofName: "union"},
{Name: "f5", Number: 5, Cardinality: pref.Optional, Kind: pref.Uint64Kind, Default: V(uint64(5)), OneofName: "union"},
{Name: "f6", Number: 6, Cardinality: pref.Optional, Kind: pref.FloatKind, Default: V(float32(6)), OneofName: "union"},
{Name: "f7", Number: 7, Cardinality: pref.Optional, Kind: pref.DoubleKind, Default: V(float64(7)), OneofName: "union"},
{Name: "f8", Number: 8, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("8")), OneofName: "union"},
{Name: "f9", Number: 9, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("9")), OneofName: "union"},
{Name: "f10", Number: 10, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("10")), OneofName: "union"},
{Name: "f11", Number: 11, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("11")), OneofName: "union"},
{Name: "f12", Number: 12, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("12")), OneofName: "union"},
{Name: "f13", Number: 13, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("13")), OneofName: "union"},
},
Oneofs: []ptype.Oneof{{Name: "union"}},
})}
empty := &OneofScalars{}
want1 := &OneofScalars{Union: &OneofScalars_Bool{true}}
want2 := &OneofScalars{Union: &OneofScalars_Int32{20}}
want3 := &OneofScalars{Union: &OneofScalars_Int64{30}}
want4 := &OneofScalars{Union: &OneofScalars_Uint32{40}}
want5 := &OneofScalars{Union: &OneofScalars_Uint64{50}}
want6 := &OneofScalars{Union: &OneofScalars_Float32{60}}
want7 := &OneofScalars{Union: &OneofScalars_Float64{70}}
want8 := &OneofScalars{Union: &OneofScalars_String{string("80")}}
want9 := &OneofScalars{Union: &OneofScalars_StringA{[]byte("90")}}
want10 := &OneofScalars{Union: &OneofScalars_StringB{MyString("100")}}
want11 := &OneofScalars{Union: &OneofScalars_Bytes{[]byte("110")}}
want12 := &OneofScalars{Union: &OneofScalars_BytesA{string("120")}}
want13 := &OneofScalars{Union: &OneofScalars_BytesB{MyBytes("130")}}
empty := mi.MessageOf(&OneofScalars{})
want1 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Bool{true}})
want2 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Int32{20}})
want3 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Int64{30}})
want4 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Uint32{40}})
want5 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Uint64{50}})
want6 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Float32{60}})
want7 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Float64{70}})
want8 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_String{string("80")}})
want9 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_StringA{[]byte("90")}})
want10 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_StringB{MyString("100")}})
want11 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Bytes{[]byte("110")}})
want12 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_BytesA{string("120")}})
want13 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_BytesB{MyBytes("130")}})
testMessage(t, nil, mi.MessageOf(&OneofScalars{}), messageOps{
testMessage(t, nil, &OneofScalars{}, messageOps{
hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false},
getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("13"))},
@ -789,13 +828,6 @@ var cmpOpts = cmp.Options{
cmp.Transformer("UnwrapValue", func(v pref.Value) interface{} {
return v.Interface()
}),
cmp.Transformer("UnwrapMessage", func(m pref.Message) interface{} {
v := m.Interface()
if v, ok := v.(interface{ Unwrap() interface{} }); ok {
return v.Unwrap()
}
return v
}),
cmp.Transformer("UnwrapVector", func(v pref.Vector) interface{} {
return v.(interface{ Unwrap() interface{} }).Unwrap()
}),

View File

@ -53,17 +53,23 @@ var (
// protoc-gen-go historically generated to be able to automatically wrap some
// v1 messages generated by other forks of protoc-gen-go.
func NewConverter(t reflect.Type, k pref.Kind) Converter {
return NewLegacyConverter(t, k, nil, nil)
return NewLegacyConverter(t, k, nil, nil, nil)
}
// Legacy enums and messages do not self-report their own protoreflect types.
// Thus, the caller needs to provide functions for retrieving those when
// a v1 enum or message is encountered.
type (
enumTypeOf = func(reflect.Type) pref.EnumType
messageTypeOf = func(reflect.Type) pref.MessageType
messageValueOf = func(reflect.Value) pref.ProtoMessage
)
// NewLegacyConverter is identical to NewConverter,
// but supports wrapping legacy v1 messages to implement the v2 message API
// using the provided wrapEnum and wrapMessage functions.
// using the provided enumTypeOf, messageTypeOf and messageValueOf functions.
// The wrapped message must implement Unwrapper.
func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapEnum func(reflect.Value) pref.ProtoEnum, wrapMessage func(reflect.Value) pref.ProtoMessage) Converter {
if (wrapEnum == nil) != (wrapMessage == nil) {
panic("legacy enum and message wrappers must both be populated or nil")
}
func NewLegacyConverter(t reflect.Type, k pref.Kind, etOf enumTypeOf, mtOf messageTypeOf, mvOf messageValueOf) Converter {
switch k {
case pref.BoolKind:
if t.Kind() == reflect.Bool {
@ -125,8 +131,8 @@ func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapEnum func(reflect.Value
}
// Handle v1 enums, which we identify as simply a named int32 type.
if wrapEnum != nil && t.PkgPath() != "" && t.Kind() == reflect.Int32 {
et := wrapEnum(reflect.Zero(t)).ProtoReflect().Type()
if etOf != nil && t.PkgPath() != "" && t.Kind() == reflect.Int32 {
et := etOf(t)
return Converter{
PBValueOf: func(v reflect.Value) pref.Value {
if v.Type() != t {
@ -164,14 +170,14 @@ func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapEnum func(reflect.Value
}
// Handle v1 messages, which we need to wrap as a v2 message.
if wrapMessage != nil && t.Kind() == reflect.Ptr && t.Implements(messageIfaceV1) {
mt := wrapMessage(reflect.New(t.Elem())).ProtoReflect().Type()
if mtOf != nil && t.Kind() == reflect.Ptr && t.Implements(messageIfaceV1) {
mt := mtOf(t)
return Converter{
PBValueOf: func(v reflect.Value) pref.Value {
if v.Type() != t {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
}
return pref.ValueOf(wrapMessage(v).ProtoReflect())
return pref.ValueOf(mvOf(v).ProtoReflect())
},
GoValueOf: func(v pref.Value) reflect.Value {
rv := reflect.ValueOf(v.Message().(Unwrapper).Unwrap())

View File

@ -77,7 +77,7 @@ func (ms mapReflect) Range(f func(pref.MapKey, pref.Value) bool) {
}
}
func (ms mapReflect) Unwrap() interface{} {
return ms.v.Interface()
return ms.v.Addr().Interface()
}
func (ms mapReflect) ProtoMutable() {}

View File

@ -53,7 +53,7 @@ func (vs vectorReflect) Truncate(i int) {
vs.v.Set(vs.v.Slice(0, i))
}
func (vs vectorReflect) Unwrap() interface{} {
return vs.v.Interface()
return vs.v.Addr().Interface()
}
func (vs vectorReflect) ProtoMutable() {}

View File

@ -446,20 +446,21 @@ type ExtensionType interface {
// t.GoType() == reflect.TypeOf(t.InterfaceOf(t.ValueOf(t.New())))
GoType() reflect.Type
// TODO: How do we reconcile GoType with the existing extension API,
// which returns *T for scalars (causing unnecessary aliasing),
// and []T for vectors (causing insufficient aliasing)?
// TODO: What to do with nil?
// Should ValueOf(nil) return Value{}?
// Should InterfaceOf(Value{}) return nil?
// Similarly, should the Value.{Message,Vector,Map} also return nil?
// ValueOf wraps the input and returns it as a Value.
// ValueOf panics if the input value is not the appropriate type.
// ValueOf panics if the input value is invalid or not the appropriate type.
//
// ValueOf is more extensive than protoreflect.ValueOf for a given field's
// value as it has more type information available.
ValueOf(interface{}) Value
// InterfaceOf completely unwraps the Value to the underlying Go type.
// InterfaceOf panics if the input does not represent the appropriate
// underlying Go type.
// InterfaceOf panics if the input is nil or does not represent the
// appropriate underlying Go type.
//
// InterfaceOf is able to unwrap the Value further than Value.Interface
// as it has more type information available.

View File

@ -19,22 +19,24 @@ func GoEnum(ed protoreflect.EnumDescriptor, fn func(protoreflect.EnumType, proto
if ed.IsPlaceholder() {
panic("enum descriptor must not be a placeholder")
}
t := &goEnum{EnumDescriptor: ed, new: fn}
t.typ = reflect.TypeOf(fn(t, 0))
return t
return &goEnum{EnumDescriptor: ed, new: fn}
}
type goEnum struct {
protoreflect.EnumDescriptor
typ reflect.Type
new func(protoreflect.EnumType, protoreflect.EnumNumber) protoreflect.ProtoEnum
once sync.Once
typ reflect.Type
}
func (t *goEnum) GoType() reflect.Type {
t.New(0) // initialize t.typ
return t.typ
}
func (t *goEnum) New(n protoreflect.EnumNumber) protoreflect.ProtoEnum {
e := t.new(t, n)
t.once.Do(func() { t.typ = reflect.TypeOf(e) })
if t.typ != reflect.TypeOf(e) {
panic(fmt.Sprintf("mismatching types for enum: got %T, want %v", e, t.typ))
}
@ -47,22 +49,26 @@ func GoMessage(md protoreflect.MessageDescriptor, fn func(protoreflect.MessageTy
if md.IsPlaceholder() {
panic("message descriptor must not be a placeholder")
}
t := &goMessage{MessageDescriptor: md, new: fn}
t.typ = reflect.TypeOf(fn(t))
return t
// NOTE: Avoid calling fn in the constructor since fn itself may depend on
// this function returning (for cyclic message dependencies).
return &goMessage{MessageDescriptor: md, new: fn}
}
type goMessage struct {
protoreflect.MessageDescriptor
typ reflect.Type
new func(protoreflect.MessageType) protoreflect.ProtoMessage
once sync.Once
typ reflect.Type
}
func (t *goMessage) GoType() reflect.Type {
t.New() // initialize t.typ
return t.typ
}
func (t *goMessage) New() protoreflect.ProtoMessage {
m := t.new(t)
t.once.Do(func() { t.typ = reflect.TypeOf(m) })
if t.typ != reflect.TypeOf(m) {
panic(fmt.Sprintf("mismatching types for message: got %T, want %v", m, t.typ))
}
@ -240,11 +246,11 @@ func (t *goExtension) lazyInit() {
t.valueOf = func(v interface{}) protoreflect.Value {
return protoreflect.ValueOf(value.VectorOf(v, c))
}
t.interfaceOf = func(v protoreflect.Value) interface{} {
t.interfaceOf = func(pv protoreflect.Value) interface{} {
// TODO: Can we assume that Vector implementations know how
// to unwrap themselves?
// Should this be part of the public API in protoreflect?
return v.Vector().(value.Unwrapper).Unwrap()
return pv.Vector().(value.Unwrapper).Unwrap()
}
default:
panic(fmt.Sprintf("invalid cardinality: %v", t.Cardinality()))