diff --git a/internal/impl/legacy_enum.go b/internal/impl/legacy_enum.go index 87b4921b..bba130db 100644 --- a/internal/impl/legacy_enum.go +++ b/internal/impl/legacy_enum.go @@ -16,29 +16,36 @@ import ( ptype "github.com/golang/protobuf/v2/reflect/prototype" ) +// legacyWrapEnum wraps v as a protoreflect.ProtoEnum, +// where v must be a *struct kind and not implement the v2 API already. +func legacyWrapEnum(v reflect.Value) pref.ProtoEnum { + et := legacyLoadEnumType(v.Type()) + return et.New(pref.EnumNumber(v.Int())) +} + var enumTypeCache sync.Map // map[reflect.Type]protoreflect.EnumType -// wrapLegacyEnum wraps v as a protoreflect.ProtoEnum, -// where v must be an int32 kind and not implement the v2 API already. -func wrapLegacyEnum(v reflect.Value) pref.ProtoEnum { +// legacyLoadEnumType dynamically loads a protoreflect.EnumType for t, +// where t must be an int32 kind and not implement the v2 API already. +func legacyLoadEnumType(t reflect.Type) pref.EnumType { // Fast-path: check if a EnumType is cached for this concrete type. - if et, ok := enumTypeCache.Load(v.Type()); ok { - return et.(pref.EnumType).New(pref.EnumNumber(v.Int())) + if et, ok := enumTypeCache.Load(t); ok { + return et.(pref.EnumType) } // Slow-path: derive enum descriptor and initialize EnumType. var m sync.Map // map[protoreflect.EnumNumber]proto.Enum - ed := loadEnumDesc(v.Type()) + ed := legacyLoadEnumDesc(t) et := ptype.GoEnum(ed, func(et pref.EnumType, n pref.EnumNumber) pref.ProtoEnum { if e, ok := m.Load(n); ok { return e.(pref.ProtoEnum) } - e := &legacyEnumWrapper{num: n, pbTyp: et, goTyp: v.Type()} + e := &legacyEnumWrapper{num: n, pbTyp: et, goTyp: t} m.Store(n, e) return e }) - enumTypeCache.Store(v.Type(), et) - return et.(pref.EnumType).New(pref.EnumNumber(v.Int())) + enumTypeCache.Store(t, et) + return et.(pref.EnumType) } type legacyEnumWrapper struct { @@ -70,9 +77,11 @@ var ( var enumDescCache sync.Map // map[reflect.Type]protoreflect.EnumDescriptor -// loadEnumDesc returns an EnumDescriptor derived from the Go type, +var enumNumberType = reflect.TypeOf(pref.EnumNumber(0)) + +// legacyLoadEnumDesc returns an EnumDescriptor derived from the Go type, // which must be an int32 kind and not implement the v2 API already. -func loadEnumDesc(t reflect.Type) pref.EnumDescriptor { +func legacyLoadEnumDesc(t reflect.Type) pref.EnumDescriptor { // Fast-path: check if an EnumDescriptor is cached for this concrete type. if v, ok := enumDescCache.Load(t); ok { return v.(pref.EnumDescriptor) @@ -82,6 +91,9 @@ func loadEnumDesc(t reflect.Type) pref.EnumDescriptor { if t.Kind() != reflect.Int32 || t.PkgPath() == "" { panic(fmt.Sprintf("got %v, want named int32 kind", t)) } + if t == enumNumberType { + panic(fmt.Sprintf("cannot be %v", t)) + } // Derive the enum descriptor from the raw descriptor proto. e := new(ptype.StandaloneEnum) @@ -91,7 +103,7 @@ func loadEnumDesc(t reflect.Type) pref.EnumDescriptor { } if ed, ok := ev.(legacyEnum); ok { b, idxs := ed.EnumDescriptor() - fd := loadFileDesc(b) + fd := legacyLoadFileDesc(b) // Derive syntax. switch fd.GetSyntax() { diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go new file mode 100644 index 00000000..cff89f7e --- /dev/null +++ b/internal/impl/legacy_extension.go @@ -0,0 +1,353 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "fmt" + "reflect" + + protoV1 "github.com/golang/protobuf/proto" + ptag "github.com/golang/protobuf/v2/internal/encoding/tag" + pvalue "github.com/golang/protobuf/v2/internal/value" + pref "github.com/golang/protobuf/v2/reflect/protoreflect" + ptype "github.com/golang/protobuf/v2/reflect/prototype" +) + +func makeLegacyExtensionFieldsFunc(t reflect.Type) func(p *messageDataType) pref.KnownFields { + f := makeLegacyExtensionMapFunc(t) + if f == nil { + return nil + } + return func(p *messageDataType) pref.KnownFields { + return legacyExtensionFields{p.mi, f(p)} + } +} + +type legacyExtensionFields struct { + mi *MessageType + x legacyExtensionIface +} + +func (p legacyExtensionFields) Len() (n int) { + p.x.Range(func(num pref.FieldNumber, _ legacyExtensionEntry) bool { + if p.Has(num) { + n++ + } + return true + }) + return n +} + +func (p legacyExtensionFields) Has(n pref.FieldNumber) bool { + x := p.x.Get(n) + if x.val == nil { + return false + } + t := legacyExtensionTypeOf(x.desc) + if t.Cardinality() == pref.Repeated { + return legacyExtensionValueOf(x.val, t).Vector().Len() > 0 + } + return true +} + +func (p legacyExtensionFields) Get(n pref.FieldNumber) pref.Value { + x := p.x.Get(n) + if x.desc == nil { + return pref.Value{} + } + t := legacyExtensionTypeOf(x.desc) + if x.val == nil { + if t.Cardinality() == pref.Repeated { + // TODO: What is the zero value for Vectors? + // TODO: This logic is racy. + v := t.ValueOf(t.New()) + x.val = legacyExtensionInterfaceOf(v, t) + p.x.Set(n, x) + return v + } + if t.Kind() == pref.MessageKind || t.Kind() == pref.GroupKind { + // TODO: What is the zero value for Messages? + return pref.Value{} + } + return t.Default() + } + return legacyExtensionValueOf(x.val, t) +} + +func (p legacyExtensionFields) Set(n pref.FieldNumber, v pref.Value) { + x := p.x.Get(n) + if x.desc == nil { + panic("no extension descriptor registered") + } + t := legacyExtensionTypeOf(x.desc) + x.val = legacyExtensionInterfaceOf(v, t) + p.x.Set(n, x) +} + +func (p legacyExtensionFields) Clear(n pref.FieldNumber) { + x := p.x.Get(n) + if x.desc == nil { + return + } + x.val = nil + p.x.Set(n, x) +} + +func (p legacyExtensionFields) Mutable(n pref.FieldNumber) pref.Mutable { + x := p.x.Get(n) + if x.desc == nil { + panic("no extension descriptor registered") + } + t := legacyExtensionTypeOf(x.desc) + if x.val == nil { + v := t.ValueOf(t.New()) + x.val = legacyExtensionInterfaceOf(v, t) + p.x.Set(n, x) + } + return legacyExtensionValueOf(x.val, t).Interface().(pref.Mutable) +} + +func (p legacyExtensionFields) Range(f func(pref.FieldNumber, pref.Value) bool) { + p.x.Range(func(n pref.FieldNumber, x legacyExtensionEntry) bool { + if p.Has(n) { + return f(n, p.Get(n)) + } + return true + }) +} + +func (p legacyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes { + return legacyExtensionTypes(p) +} + +type legacyExtensionTypes legacyExtensionFields + +func (p legacyExtensionTypes) Len() (n int) { + p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool { + if x.desc != nil { + n++ + } + return true + }) + return n +} + +func (p legacyExtensionTypes) Register(t pref.ExtensionType) { + if p.mi.Type.FullName() != t.ExtendedType().FullName() { + panic("extended type mismatch") + } + if !p.mi.Type.ExtensionRanges().Has(t.Number()) { + panic("invalid extension field number") + } + x := p.x.Get(t.Number()) + if x.desc != nil { + panic("extension descriptor already registered") + } + x.desc = legacyExtensionDescOf(t, p.mi.goType) + p.x.Set(t.Number(), x) +} + +func (p legacyExtensionTypes) Remove(t pref.ExtensionType) { + if !p.mi.Type.ExtensionRanges().Has(t.Number()) { + return + } + x := p.x.Get(t.Number()) + if x.val != nil { + panic("value for extension descriptor still populated") + } + x.desc = nil + if len(x.raw) == 0 { + p.x.Clear(t.Number()) + } else { + p.x.Set(t.Number(), x) + } +} + +func (p legacyExtensionTypes) ByNumber(n pref.FieldNumber) pref.ExtensionType { + x := p.x.Get(n) + if x.desc != nil { + return legacyExtensionTypeOf(x.desc) + } + return nil +} + +func (p legacyExtensionTypes) ByName(s pref.FullName) (t pref.ExtensionType) { + p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool { + if x.desc != nil && x.desc.Name == string(s) { + t = legacyExtensionTypeOf(x.desc) + return false + } + return true + }) + return t +} + +func (p legacyExtensionTypes) Range(f func(pref.ExtensionType) bool) { + p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool { + if x.desc != nil { + if !f(legacyExtensionTypeOf(x.desc)) { + return false + } + } + return true + }) +} + +func legacyExtensionDescOf(t pref.ExtensionType, parent reflect.Type) *protoV1.ExtensionDesc { + if t, ok := t.(*legacyExtensionType); ok { + return t.desc + } + + // Determine the v1 extension type, which is unfortunately not the same as + // the v2 ExtensionType.GoType. + extType := t.GoType() + switch extType.Kind() { + case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: + extType = reflect.PtrTo(extType) // T -> *T for singular scalar fields + case reflect.Ptr: + if extType.Elem().Kind() == reflect.Slice { + extType = extType.Elem() // *[]T -> []T for repeated fields + } + } + + // Reconstruct the legacy enum full name, which is an odd mixture of the + // proto package name with the Go type name. + var enumName string + if t.Kind() == pref.EnumKind { + enumName = t.GoType().Name() + for d, ok := pref.Descriptor(t.EnumType()), true; ok; d, ok = d.Parent() { + if fd, _ := d.(pref.FileDescriptor); fd != nil && fd.Package() != "" { + enumName = string(fd.Package()) + "." + enumName + } + } + } + + // Construct and return a v1 ExtensionDesc. + return &protoV1.ExtensionDesc{ + ExtendedType: reflect.Zero(parent).Interface().(protoV1.Message), + ExtensionType: reflect.Zero(extType).Interface(), + Field: int32(t.Number()), + Name: string(t.FullName()), + Tag: ptag.Marshal(t, enumName), + } +} + +func legacyExtensionTypeOf(d *protoV1.ExtensionDesc) pref.ExtensionType { + // TODO: Add a field to protoV1.ExtensionDesc to contain a v2 descriptor. + + // Derive basic field information from the struct tag. + t := reflect.TypeOf(d.ExtensionType) + isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct + isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 + if isOptional || isRepeated { + t = t.Elem() + } + f := ptag.Unmarshal(d.Tag, t) + + // Construct a v2 ExtensionType. + conv := newConverter(t, f.Kind) + xd, err := ptype.NewExtension(&ptype.StandaloneExtension{ + FullName: pref.FullName(d.Name), + Number: pref.FieldNumber(d.Field), + Cardinality: f.Cardinality, + Kind: f.Kind, + Default: f.Default, + Options: f.Options, + EnumType: conv.EnumType, + MessageType: conv.MessageType, + ExtendedType: legacyLoadMessageDesc(reflect.TypeOf(d.ExtendedType)), + }) + if err != nil { + panic(err) + } + xt := ptype.GoExtension(xd, conv.EnumType, conv.MessageType) + + // Return the extension type as is if the dependencies already support v2. + xt2 := &legacyExtensionType{ExtensionType: xt, desc: d} + if !conv.IsLegacy { + return xt2 + } + + // If the dependency is a v1 enum or message, we need to create a custom + // extension type where ExtensionType.GoType continues to use the legacy + // v1 Go type, instead of the wrapped versions that satisfy the v2 API. + if xd.Cardinality() != pref.Repeated { + // Custom extension type for singular enums and messages. + // The legacy wrappers use legacyEnumWrapper and legacyMessageWrapper + // to implement the v2 interfaces for enums and messages. + // Both of those type satisfy the value.Unwrapper interface. + xt2.typ = t + xt2.new = func() interface{} { + return xt.New().(pvalue.Unwrapper).Unwrap() + } + xt2.valueOf = func(v interface{}) pref.Value { + if reflect.TypeOf(v) != xt2.typ { + panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ)) + } + if xd.Kind() == pref.EnumKind { + return xt.ValueOf(legacyWrapEnum(reflect.ValueOf(v))) + } else { + return xt.ValueOf(legacyWrapMessage(reflect.ValueOf(v))) + } + } + xt2.interfaceOf = func(v pref.Value) interface{} { + return xt.InterfaceOf(v).(pvalue.Unwrapper).Unwrap() + } + } else { + // Custom extension type for repeated enums and messages. + xt2.typ = reflect.PtrTo(reflect.SliceOf(t)) + xt2.new = func() interface{} { + return reflect.New(xt2.typ.Elem()).Interface() + } + xt2.valueOf = func(v interface{}) pref.Value { + if reflect.TypeOf(v) != xt2.typ { + panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ)) + } + return pref.ValueOf(pvalue.VectorOf(v, conv)) + } + xt2.interfaceOf = func(pv pref.Value) interface{} { + v := pv.Vector().(pvalue.Unwrapper).Unwrap() + if reflect.TypeOf(v) != xt2.typ { + panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ)) + } + return v + } + } + return xt2 +} + +type legacyExtensionType struct { + pref.ExtensionType + desc *protoV1.ExtensionDesc + typ reflect.Type + new func() interface{} + valueOf func(interface{}) pref.Value + interfaceOf func(pref.Value) interface{} +} + +func (x *legacyExtensionType) GoType() reflect.Type { + if x.typ != nil { + return x.typ + } + return x.ExtensionType.GoType() +} +func (x *legacyExtensionType) New() interface{} { + if x.new != nil { + return x.new() + } + return x.ExtensionType.New() +} +func (x *legacyExtensionType) ValueOf(v interface{}) pref.Value { + if x.valueOf != nil { + return x.valueOf(v) + } + return x.ExtensionType.ValueOf(v) +} +func (x *legacyExtensionType) InterfaceOf(v pref.Value) interface{} { + if x.interfaceOf != nil { + return x.interfaceOf(v) + } + return x.ExtensionType.InterfaceOf(v) +} diff --git a/internal/impl/legacy_extension_hack.go b/internal/impl/legacy_extension_hack.go index b2295814..295798ff 100644 --- a/internal/impl/legacy_extension_hack.go +++ b/internal/impl/legacy_extension_hack.go @@ -13,10 +13,13 @@ import ( pref "github.com/golang/protobuf/v2/reflect/protoreflect" ) -// TODO: The logic below this is a hack since v1 currently exposes no -// exported functionality for interacting with these data structures. -// Eventually make changes to v1 such that v2 can access the necessary -// fields without relying on unsafe. +// TODO: The logic in the file is a hack and should be in the v1 repository. +// We need to break the dependency on proto v1 since it is v1 that will +// eventually need to depend on v2. + +// TODO: The v1 API currently exposes no exported functionality for interacting +// with the extension data structures. We will need to make changes in v1 so +// that v2 can access these data structures without relying on unsafe. var ( extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil)) @@ -25,8 +28,10 @@ var ( type legacyExtensionIface interface { Len() int + Has(pref.FieldNumber) bool Get(pref.FieldNumber) legacyExtensionEntry Set(pref.FieldNumber, legacyExtensionEntry) + Clear(pref.FieldNumber) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) } @@ -49,6 +54,13 @@ func makeLegacyExtensionMapFunc(t reflect.Type) func(*messageDataType) legacyExt } } +// TODO: We currently don't do locking with legacyExtensionSyncMap.p.mu. +// The locking behavior was already obscure "feature" beforehand, +// and it is not obvious how it translates to the v2 API. +// The v2 API presents a Range method, which calls a user provided function, +// which may in turn call other methods on the map. In such a use case, +// acquiring a lock within each method would result in a reentrant deadlock. + // legacyExtensionSyncMap is identical to protoV1.XXX_InternalExtensions. // It implements legacyExtensionIface. type legacyExtensionSyncMap struct { @@ -62,16 +74,15 @@ func (m legacyExtensionSyncMap) Len() int { if m.p == nil { return 0 } - m.p.mu.Lock() - defer m.p.mu.Unlock() return m.p.m.Len() } +func (m legacyExtensionSyncMap) Has(n pref.FieldNumber) bool { + return m.p.m.Has(n) +} func (m legacyExtensionSyncMap) Get(n pref.FieldNumber) legacyExtensionEntry { if m.p == nil { return legacyExtensionEntry{} } - m.p.mu.Lock() - defer m.p.mu.Unlock() return m.p.m.Get(n) } func (m *legacyExtensionSyncMap) Set(n pref.FieldNumber, x legacyExtensionEntry) { @@ -81,16 +92,15 @@ func (m *legacyExtensionSyncMap) Set(n pref.FieldNumber, x legacyExtensionEntry) m legacyExtensionMap }) } - m.p.mu.Lock() - defer m.p.mu.Unlock() m.p.m.Set(n, x) } +func (m legacyExtensionSyncMap) Clear(n pref.FieldNumber) { + m.p.m.Clear(n) +} func (m legacyExtensionSyncMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) { if m.p == nil { return } - m.p.mu.Lock() - defer m.p.mu.Unlock() m.p.m.Range(f) } @@ -101,6 +111,10 @@ type legacyExtensionMap map[pref.FieldNumber]legacyExtensionEntry func (m legacyExtensionMap) Len() int { return len(m) } +func (m legacyExtensionMap) Has(n pref.FieldNumber) bool { + _, ok := m[n] + return ok +} func (m legacyExtensionMap) Get(n pref.FieldNumber) legacyExtensionEntry { return m[n] } @@ -110,6 +124,9 @@ func (m *legacyExtensionMap) Set(n pref.FieldNumber, x legacyExtensionEntry) { } (*m)[n] = x } +func (m *legacyExtensionMap) Clear(n pref.FieldNumber) { + delete(*m, n) +} func (m legacyExtensionMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) { for n, x := range m { if !f(n, x) { @@ -124,3 +141,76 @@ type legacyExtensionEntry struct { val interface{} raw []byte } + +// TODO: The legacyExtensionInterfaceOf and legacyExtensionValueOf converters +// exist since the current storage representation in the v1 data structures use +// *T for scalars and []T for repeated fields, but the v2 API operates on +// T for scalars and *[]T for repeated fields. +// +// Instead of maintaining this technical debt in the v2 repository, +// we can offload this into the v1 implementation such that it uses a +// storage representation that is appropriate for v2, and uses the these +// functions to present the illusion that that the underlying storage +// is still *T and []T. +// +// See https://github.com/golang/protobuf/pull/746 +const hasPR746 = true + +// legacyExtensionInterfaceOf converts a protoreflect.Value to the +// storage representation used in v1 extension data structures. +// +// In particular, it represents scalars (except []byte) a pointer to the value, +// and repeated fields as the a slice value itself. +func legacyExtensionInterfaceOf(pv pref.Value, t pref.ExtensionType) interface{} { + v := t.InterfaceOf(pv) + if !hasPR746 { + switch rv := reflect.ValueOf(v); rv.Kind() { + case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: + // Represent primitive types as a pointer to the value. + rv2 := reflect.New(rv.Type()) + rv2.Elem().Set(rv) + v = rv2.Interface() + case reflect.Ptr: + // Represent pointer to slice types as the value itself. + switch rv.Type().Elem().Kind() { + case reflect.Slice: + if rv.IsNil() { + v = reflect.Zero(rv.Type().Elem()).Interface() + } else { + v = rv.Elem().Interface() + } + } + } + } + return v +} + +// legacyExtensionValueOf converts the storage representation of a value in +// the v1 extension data structures to a protoreflect.Value. +// +// In particular, it represents scalars as the value itself, +// and repeated fields as a pointer to the slice value. +func legacyExtensionValueOf(v interface{}, t pref.ExtensionType) pref.Value { + if !hasPR746 { + switch rv := reflect.ValueOf(v); rv.Kind() { + case reflect.Ptr: + // Represent slice types as the value itself. + switch rv.Type().Elem().Kind() { + case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: + if rv.IsNil() { + v = reflect.Zero(rv.Type().Elem()).Interface() + } else { + v = rv.Elem().Interface() + } + } + case reflect.Slice: + // Represent slice types (except []byte) as a pointer to the value. + if rv.Type().Elem().Kind() != reflect.Uint8 { + rv2 := reflect.New(rv.Type()) + rv2.Elem().Set(rv) + v = rv2.Interface() + } + } + } + return t.ValueOf(v) +} diff --git a/internal/impl/legacy_file.go b/internal/impl/legacy_file.go index 9a123c2b..8aaa9363 100644 --- a/internal/impl/legacy_file.go +++ b/internal/impl/legacy_file.go @@ -36,13 +36,13 @@ type ( var fileDescCache sync.Map // map[*byte]*descriptorV1.FileDescriptorProto -// loadFileDesc unmarshals b as a compressed FileDescriptorProto message. +// legacyLoadFileDesc unmarshals b as a compressed FileDescriptorProto message. // // This assumes that b is immutable and that b does not refer to part of a // concatenated series of GZIP files (which would require shenanigans that // rely on the concatenation properties of both protobufs and GZIP). // File descriptors generated by protoc-gen-go do not rely on that property. -func loadFileDesc(b []byte) *descriptorV1.FileDescriptorProto { +func legacyLoadFileDesc(b []byte) *descriptorV1.FileDescriptorProto { // Fast-path: check whether we already have a cached file descriptor. if v, ok := fileDescCache.Load(&b[0]); ok { return v.(*descriptorV1.FileDescriptorProto) diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go index 3c43bfab..4b643590 100644 --- a/internal/impl/legacy_message.go +++ b/internal/impl/legacy_message.go @@ -19,26 +19,38 @@ import ( ptype "github.com/golang/protobuf/v2/reflect/prototype" ) +// legacyWrapMessage wraps v as a protoreflect.ProtoMessage, +// where v must be a *struct kind and not implement the v2 API already. +func legacyWrapMessage(v reflect.Value) pref.ProtoMessage { + mt := legacyLoadMessageType(v.Type()) + return (*legacyMessageWrapper)(mt.dataTypeOf(v.Interface())) +} + var messageTypeCache sync.Map // map[reflect.Type]*MessageType -// wrapLegacyMessage wraps v as a protoreflect.ProtoMessage, -// where v must be a *struct kind and not implement the v2 API already. -func wrapLegacyMessage(v reflect.Value) pref.ProtoMessage { +// legacyLoadMessageType dynamically loads a *MessageType for t, +// where t must be a *struct kind and not implement the v2 API already. +func legacyLoadMessageType(t reflect.Type) *MessageType { // Fast-path: check if a MessageType is cached for this concrete type. - if mt, ok := messageTypeCache.Load(v.Type()); ok { - return mt.(*MessageType).MessageOf(v.Interface()).Interface() + if mt, ok := messageTypeCache.Load(t); ok { + return mt.(*MessageType) } // Slow-path: derive message descriptor and initialize MessageType. - mt := &MessageType{Desc: loadMessageDesc(v.Type())} - messageTypeCache.Store(v.Type(), mt) - return mt.MessageOf(v.Interface()).Interface() + md := legacyLoadMessageDesc(t) + mt := new(MessageType) + mt.Type = ptype.GoMessage(md, func(pref.MessageType) pref.ProtoMessage { + p := reflect.New(t.Elem()).Interface() + return (*legacyMessageWrapper)(mt.dataTypeOf(p)) + }) + messageTypeCache.Store(t, mt) + return mt } type legacyMessageWrapper messageDataType func (m *legacyMessageWrapper) Type() pref.MessageType { - return m.mi.pbType + return m.mi.Type } func (m *legacyMessageWrapper) KnownFields() pref.KnownFields { return (*knownFields)(m) @@ -65,9 +77,9 @@ var ( var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor -// loadMessageDesc returns an MessageDescriptor derived from the Go type, +// legacyLoadMessageDesc returns an MessageDescriptor derived from the Go type, // which must be a *struct kind and not implement the v2 API already. -func loadMessageDesc(t reflect.Type) pref.MessageDescriptor { +func legacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor { return messageDescSet{}.Load(t) } @@ -126,7 +138,7 @@ func (ms *messageDescSet) processMessage(t reflect.Type) pref.MessageDescriptor } if md, ok := mv.(legacyMessage); ok { b, idxs := md.Descriptor() - fd := loadFileDesc(b) + fd := legacyLoadFileDesc(b) // Derive syntax. switch fd.GetSyntax() { @@ -231,7 +243,7 @@ func (ms *messageDescSet) parseField(tag, tagKey, tagVal string, goType reflect. if ev, ok := reflect.Zero(t).Interface().(pref.ProtoEnum); ok { f.EnumType = ev.ProtoReflect().Type() } else { - f.EnumType = loadEnumDesc(t) + f.EnumType = legacyLoadEnumDesc(t) } } if f.MessageType == nil && (f.Kind == pref.MessageKind || f.Kind == pref.GroupKind) { diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go index 379615ba..9f8456ed 100644 --- a/internal/impl/legacy_test.go +++ b/internal/impl/legacy_test.go @@ -30,7 +30,7 @@ import ( ) func mustLoadFileDesc(b []byte, _ []int) pref.FileDescriptor { - fd, err := ptype.NewFileFromDescriptorProto(loadFileDesc(b), nil) + fd, err := ptype.NewFileFromDescriptorProto(legacyLoadFileDesc(b), nil) if err != nil { panic(err) } @@ -42,286 +42,286 @@ func TestLegacyDescriptor(t *testing.T) { fileDescP2_20160225 := mustLoadFileDesc(new(proto2_20160225.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto2_20160225.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20160225.SiblingEnum(0))), want: fileDescP2_20160225.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto2_20160225.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20160225.Message_ChildEnum(0))), want: fileDescP2_20160225.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.SiblingMessage))), want: fileDescP2_20160225.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ChildMessage))), want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message))), want: fileDescP2_20160225.Messages().ByName("Message"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_NamedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_NamedGroup))), want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("NamedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_OptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_OptionalGroup))), want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("OptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_RequiredGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_RequiredGroup))), want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("RequiredGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_RepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_RepeatedGroup))), want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("RepeatedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_OneofGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_OneofGroup))), want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("OneofGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ExtensionOptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ExtensionOptionalGroup))), want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ExtensionRepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160225.Message_ExtensionRepeatedGroup))), want: fileDescP2_20160225.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"), }}...) fileDescP3_20160225 := mustLoadFileDesc(new(proto3_20160225.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto3_20160225.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20160225.SiblingEnum(0))), want: fileDescP3_20160225.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto3_20160225.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20160225.Message_ChildEnum(0))), want: fileDescP3_20160225.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20160225.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160225.SiblingMessage))), want: fileDescP3_20160225.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20160225.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160225.Message_ChildMessage))), want: fileDescP3_20160225.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20160225.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160225.Message))), want: fileDescP3_20160225.Messages().ByName("Message"), }}...) fileDescP2_20160519 := mustLoadFileDesc(new(proto2_20160519.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto2_20160519.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20160519.SiblingEnum(0))), want: fileDescP2_20160519.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto2_20160519.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20160519.Message_ChildEnum(0))), want: fileDescP2_20160519.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.SiblingMessage))), want: fileDescP2_20160519.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ChildMessage))), want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message))), want: fileDescP2_20160519.Messages().ByName("Message"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_NamedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_NamedGroup))), want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("NamedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_OptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_OptionalGroup))), want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("OptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_RequiredGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_RequiredGroup))), want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("RequiredGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_RepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_RepeatedGroup))), want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("RepeatedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_OneofGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_OneofGroup))), want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("OneofGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ExtensionOptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ExtensionOptionalGroup))), want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ExtensionRepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20160519.Message_ExtensionRepeatedGroup))), want: fileDescP2_20160519.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"), }}...) fileDescP3_20160519 := mustLoadFileDesc(new(proto3_20160519.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto3_20160519.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20160519.SiblingEnum(0))), want: fileDescP3_20160519.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto3_20160519.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20160519.Message_ChildEnum(0))), want: fileDescP3_20160519.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20160519.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160519.SiblingMessage))), want: fileDescP3_20160519.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20160519.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160519.Message_ChildMessage))), want: fileDescP3_20160519.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20160519.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20160519.Message))), want: fileDescP3_20160519.Messages().ByName("Message"), }}...) fileDescP2_20180125 := mustLoadFileDesc(new(proto2_20180125.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto2_20180125.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180125.SiblingEnum(0))), want: fileDescP2_20180125.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto2_20180125.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180125.Message_ChildEnum(0))), want: fileDescP2_20180125.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.SiblingMessage))), want: fileDescP2_20180125.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ChildMessage))), want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message))), want: fileDescP2_20180125.Messages().ByName("Message"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_NamedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_NamedGroup))), want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("NamedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_OptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_OptionalGroup))), want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("OptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_RequiredGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_RequiredGroup))), want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("RequiredGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_RepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_RepeatedGroup))), want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("RepeatedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_OneofGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_OneofGroup))), want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("OneofGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ExtensionOptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ExtensionOptionalGroup))), want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ExtensionRepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180125.Message_ExtensionRepeatedGroup))), want: fileDescP2_20180125.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"), }}...) fileDescP3_20180125 := mustLoadFileDesc(new(proto3_20180125.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto3_20180125.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180125.SiblingEnum(0))), want: fileDescP3_20180125.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto3_20180125.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180125.Message_ChildEnum(0))), want: fileDescP3_20180125.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180125.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180125.SiblingMessage))), want: fileDescP3_20180125.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180125.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180125.Message_ChildMessage))), want: fileDescP3_20180125.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180125.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180125.Message))), want: fileDescP3_20180125.Messages().ByName("Message"), }}...) fileDescP2_20180430 := mustLoadFileDesc(new(proto2_20180430.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto2_20180430.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180430.SiblingEnum(0))), want: fileDescP2_20180430.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto2_20180430.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180430.Message_ChildEnum(0))), want: fileDescP2_20180430.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.SiblingMessage))), want: fileDescP2_20180430.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ChildMessage))), want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message))), want: fileDescP2_20180430.Messages().ByName("Message"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_NamedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_NamedGroup))), want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("NamedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_OptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_OptionalGroup))), want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("OptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_RequiredGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_RequiredGroup))), want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("RequiredGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_RepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_RepeatedGroup))), want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("RepeatedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_OneofGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_OneofGroup))), want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("OneofGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ExtensionOptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ExtensionOptionalGroup))), want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ExtensionRepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180430.Message_ExtensionRepeatedGroup))), want: fileDescP2_20180430.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"), }}...) fileDescP3_20180430 := mustLoadFileDesc(new(proto3_20180430.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto3_20180430.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180430.SiblingEnum(0))), want: fileDescP3_20180430.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto3_20180430.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180430.Message_ChildEnum(0))), want: fileDescP3_20180430.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180430.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180430.SiblingMessage))), want: fileDescP3_20180430.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180430.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180430.Message_ChildMessage))), want: fileDescP3_20180430.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180430.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180430.Message))), want: fileDescP3_20180430.Messages().ByName("Message"), }}...) fileDescP2_20180814 := mustLoadFileDesc(new(proto2_20180814.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto2_20180814.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180814.SiblingEnum(0))), want: fileDescP2_20180814.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto2_20180814.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto2_20180814.Message_ChildEnum(0))), want: fileDescP2_20180814.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.SiblingMessage))), want: fileDescP2_20180814.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ChildMessage))), want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message))), want: fileDescP2_20180814.Messages().ByName("Message"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_NamedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_NamedGroup))), want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("NamedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_OptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_OptionalGroup))), want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("OptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_RequiredGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_RequiredGroup))), want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("RequiredGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_RepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_RepeatedGroup))), want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("RepeatedGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_OneofGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_OneofGroup))), want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("OneofGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ExtensionOptionalGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ExtensionOptionalGroup))), want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("ExtensionOptionalGroup"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ExtensionRepeatedGroup))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto2_20180814.Message_ExtensionRepeatedGroup))), want: fileDescP2_20180814.Messages().ByName("Message").Messages().ByName("ExtensionRepeatedGroup"), }}...) fileDescP3_20180814 := mustLoadFileDesc(new(proto3_20180814.Message).Descriptor()) tests = append(tests, []struct{ got, want pref.Descriptor }{{ - got: loadEnumDesc(reflect.TypeOf(proto3_20180814.SiblingEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180814.SiblingEnum(0))), want: fileDescP3_20180814.Enums().ByName("SiblingEnum"), }, { - got: loadEnumDesc(reflect.TypeOf(proto3_20180814.Message_ChildEnum(0))), + got: legacyLoadEnumDesc(reflect.TypeOf(proto3_20180814.Message_ChildEnum(0))), want: fileDescP3_20180814.Messages().ByName("Message").Enums().ByName("ChildEnum"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180814.SiblingMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180814.SiblingMessage))), want: fileDescP3_20180814.Messages().ByName("SiblingMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180814.Message_ChildMessage))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180814.Message_ChildMessage))), want: fileDescP3_20180814.Messages().ByName("Message").Messages().ByName("ChildMessage"), }, { - got: loadMessageDesc(reflect.TypeOf(new(proto3_20180814.Message))), + got: legacyLoadMessageDesc(reflect.TypeOf(new(proto3_20180814.Message))), want: fileDescP3_20180814.Messages().ByName("Message"), }}...) @@ -383,13 +383,16 @@ func TestLegacyDescriptor(t *testing.T) { } } -type legacyUnknownMessage struct { +type legacyTestMessage struct { XXX_unrecognized []byte protoV1.XXX_InternalExtensions } -func (*legacyUnknownMessage) ExtensionRangeArray() []protoV1.ExtensionRange { - return []protoV1.ExtensionRange{{Start: 10, End: 20}, {Start: 40, End: 80}} +func (*legacyTestMessage) Reset() {} +func (*legacyTestMessage) String() string { return "" } +func (*legacyTestMessage) ProtoMessage() {} +func (*legacyTestMessage) ExtensionRangeArray() []protoV1.ExtensionRange { + return []protoV1.ExtensionRange{{Start: 10, End: 20}, {Start: 40, End: 80}, {Start: 10000, End: 20000}} } func TestLegacyUnknown(t *testing.T) { @@ -422,8 +425,8 @@ func TestLegacyUnknown(t *testing.T) { return out } - m := new(legacyUnknownMessage) - fs := new(MessageType).MessageOf(m).UnknownFields() + m := new(legacyTestMessage) + fs := MessageOf(m).UnknownFields() if got, want := fs.Len(), 0; got != want { t.Errorf("Len() = %d, want %d", got, want) @@ -618,3 +621,284 @@ func TestLegacyUnknown(t *testing.T) { return i < 2 }) } + +func TestLegactExtensions(t *testing.T) { + extensions := []pref.ExtensionType{ + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: (*bool)(nil), + Field: 10000, + Name: "fizz.buzz.optional_bool", + Tag: "varint,10000,opt,name=optional_bool,json=optionalBool,def=1", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: (*int32)(nil), + Field: 10001, + Name: "fizz.buzz.optional_int32", + Tag: "varint,10001,opt,name=optional_int32,json=optionalInt32,def=-12345", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: (*uint32)(nil), + Field: 10002, + Name: "fizz.buzz.optional_uint32", + Tag: "varint,10002,opt,name=optional_uint32,json=optionalUint32,def=3200", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: (*float32)(nil), + Field: 10003, + Name: "fizz.buzz.optional_float", + Tag: "fixed32,10003,opt,name=optional_float,json=optionalFloat,def=3.14159", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: (*string)(nil), + Field: 10004, + Name: "fizz.buzz.optional_string", + Tag: "bytes,10004,opt,name=optional_string,json=optionalString,def=hello, \"world!\"\n", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([]byte)(nil), + Field: 10005, + Name: "fizz.buzz.optional_bytes", + Tag: "bytes,10005,opt,name=optional_bytes,json=optionalBytes,def=dead\\336\\255\\276\\357beef", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: (*proto2_20180125.Message_ChildEnum)(nil), + Field: 10006, + Name: "fizz.buzz.optional_enum", + Tag: "varint,10006,opt,name=optional_enum,json=optionalEnum,enum=google.golang.org.proto2_20180125.Message_ChildEnum,def=0", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: (*proto2_20180125.Message_ChildMessage)(nil), + Field: 10007, + Name: "fizz.buzz.optional_message", + Tag: "bytes,10007,opt,name=optional_message,json=optionalMessage", + Filename: "fizz/buzz/test.proto", + }), + // TODO: Test v2 enum and messages. + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([]bool)(nil), + Field: 10008, + Name: "fizz.buzz.repeated_bool", + Tag: "varint,10008,rep,name=repeated_bool,json=repeatedBool", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([]int32)(nil), + Field: 10009, + Name: "fizz.buzz.repeated_int32", + Tag: "varint,10009,rep,name=repeated_int32,json=repeatedInt32", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([]uint32)(nil), + Field: 10010, + Name: "fizz.buzz.repeated_uint32", + Tag: "varint,10010,rep,name=repeated_uint32,json=repeatedUint32", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([]float32)(nil), + Field: 10011, + Name: "fizz.buzz.repeated_float", + Tag: "fixed32,10011,rep,name=repeated_float,json=repeatedFloat", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([]string)(nil), + Field: 10012, + Name: "fizz.buzz.repeated_string", + Tag: "bytes,10012,rep,name=repeated_string,json=repeatedString", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([][]byte)(nil), + Field: 10013, + Name: "fizz.buzz.repeated_bytes", + Tag: "bytes,10013,rep,name=repeated_bytes,json=repeatedBytes", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([]proto2_20180125.Message_ChildEnum)(nil), + Field: 10014, + Name: "fizz.buzz.repeated_enum", + Tag: "varint,10014,rep,name=repeated_enum,json=repeatedEnum,enum=google.golang.org.proto2_20180125.Message_ChildEnum", + Filename: "fizz/buzz/test.proto", + }), + legacyExtensionTypeOf(&protoV1.ExtensionDesc{ + ExtendedType: (*legacyTestMessage)(nil), + ExtensionType: ([]*proto2_20180125.Message_ChildMessage)(nil), + Field: 10015, + Name: "fizz.buzz.repeated_message", + Tag: "bytes,10015,rep,name=repeated_message,json=repeatedMessage", + Filename: "fizz/buzz/test.proto", + }), + // TODO: Test v2 enum and messages. + } + opts := cmp.Options{cmp.Comparer(func(x, y *proto2_20180125.Message_ChildMessage) bool { + return x == y // pointer compare messages for object identity + })} + + m := new(legacyTestMessage) + fs := MessageOf(m).KnownFields() + ts := fs.ExtensionTypes() + + if n := fs.Len(); n != 0 { + t.Errorf("KnownFields.Len() = %v, want 0", n) + } + if n := ts.Len(); n != 0 { + t.Errorf("ExtensionFieldTypes.Len() = %v, want 0", n) + } + + // Register all the extension types. + for _, xt := range extensions { + ts.Register(xt) + } + + // Check that getting the zero value returns the default value for scalars, + // nil for singular messages, and an empty vector for repeated fields. + defaultValues := []interface{}{ + bool(true), + int32(-12345), + uint32(3200), + float32(3.14159), + string("hello, \"world!\"\n"), + []byte("dead\xde\xad\xbe\xefbeef"), + proto2_20180125.Message_ALPHA, + nil, + new([]bool), + new([]int32), + new([]uint32), + new([]float32), + new([]string), + new([][]byte), + new([]proto2_20180125.Message_ChildEnum), + new([]*proto2_20180125.Message_ChildMessage), + } + for i, xt := range extensions { + var got interface{} + v := fs.Get(xt.Number()) + if xt.Cardinality() != pref.Repeated && xt.Kind() == pref.MessageKind { + got = v.Interface() + } else { + got = xt.InterfaceOf(v) // TODO: Simplify this if InterfaceOf allows nil + } + want := defaultValues[i] + if diff := cmp.Diff(want, got, opts); diff != "" { + t.Errorf("KnownFields.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff) + } + } + + // All fields should be unpopulated. + for _, xt := range extensions { + if fs.Has(xt.Number()) { + t.Errorf("KnownFields.Has(%d) = true, want false", xt.Number()) + } + } + + // Set some values and append to values to the vectors. + m1 := &proto2_20180125.Message_ChildMessage{F1: protoV1.String("m1")} + m2 := &proto2_20180125.Message_ChildMessage{F1: protoV1.String("m2")} + setValues := []interface{}{ + bool(false), + int32(-54321), + uint32(6400), + float32(2.71828), + string("goodbye, \"world!\"\n"), + []byte("live\xde\xad\xbe\xefchicken"), + proto2_20180125.Message_CHARLIE, + m1, + &[]bool{true}, + &[]int32{-1000}, + &[]uint32{1280}, + &[]float32{1.6180}, + &[]string{"zero"}, + &[][]byte{[]byte("zero")}, + &[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO}, + &[]*proto2_20180125.Message_ChildMessage{m2}, + } + for i, xt := range extensions { + fs.Set(xt.Number(), xt.ValueOf(setValues[i])) + } + for i, xt := range extensions[len(extensions)/2:] { + v := extensions[i].ValueOf(setValues[i]) + fs.Get(xt.Number()).Vector().Append(v) + } + + // Get the values and check for equality. + getValues := []interface{}{ + bool(false), + int32(-54321), + uint32(6400), + float32(2.71828), + string("goodbye, \"world!\"\n"), + []byte("live\xde\xad\xbe\xefchicken"), + proto2_20180125.Message_ChildEnum(proto2_20180125.Message_CHARLIE), + m1, + &[]bool{true, false}, + &[]int32{-1000, -54321}, + &[]uint32{1280, 6400}, + &[]float32{1.6180, 2.71828}, + &[]string{"zero", "goodbye, \"world!\"\n"}, + &[][]byte{[]byte("zero"), []byte("live\xde\xad\xbe\xefchicken")}, + &[]proto2_20180125.Message_ChildEnum{proto2_20180125.Message_BRAVO, proto2_20180125.Message_CHARLIE}, + &[]*proto2_20180125.Message_ChildMessage{m2, m1}, + } + for i, xt := range extensions { + got := xt.InterfaceOf(fs.Get(xt.Number())) + want := getValues[i] + if diff := cmp.Diff(want, got, opts); diff != "" { + t.Errorf("KnownFields.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff) + } + } + + if n := fs.Len(); n != 16 { + t.Errorf("KnownFields.Len() = %v, want 0", n) + } + if n := ts.Len(); n != 16 { + t.Errorf("ExtensionFieldTypes.Len() = %v, want 16", n) + } + + // Clear the field for all extension types. + for _, xt := range extensions { + fs.Clear(xt.Number()) + } + if n := fs.Len(); n != 0 { + t.Errorf("KnownFields.Len() = %v, want 0", n) + } + if n := ts.Len(); n != 16 { + t.Errorf("ExtensionFieldTypes.Len() = %v, want 16", n) + } + + // De-register all extension types. + for _, xt := range extensions { + ts.Remove(xt) + } + if n := fs.Len(); n != 0 { + t.Errorf("KnownFields.Len() = %v, want 0", n) + } + if n := ts.Len(); n != 0 { + t.Errorf("ExtensionFieldTypes.Len() = %v, want 0", n) + } + +} diff --git a/internal/impl/legacy_unknown.go b/internal/impl/legacy_unknown.go index 172ce709..e0cb0358 100644 --- a/internal/impl/legacy_unknown.go +++ b/internal/impl/legacy_unknown.go @@ -29,7 +29,7 @@ func makeLegacyUnknownFieldsFunc(t reflect.Type) func(p *messageDataType) pref.U if extFunc != nil { return func(p *messageDataType) pref.UnknownFields { return &legacyUnknownBytesAndExtensionMap{ - unkFunc(p), extFunc(p), p.mi.Desc.ExtensionRanges(), + unkFunc(p), extFunc(p), p.mi.Type.ExtensionRanges(), } } } diff --git a/internal/impl/message.go b/internal/impl/message.go index ab889cce..57ca93a6 100644 --- a/internal/impl/message.go +++ b/internal/impl/message.go @@ -12,24 +12,30 @@ import ( "sync" pref "github.com/golang/protobuf/v2/reflect/protoreflect" - ptype "github.com/golang/protobuf/v2/reflect/prototype" ) +// MessageOf returns the protoreflect.Message interface over p. +// If p already implements proto.Message, then it directly calls the +// ProtoReflect method, otherwise it wraps the legacy v1 message to implement +// the v2 reflective interface. +func MessageOf(p interface{}) pref.Message { + if m, ok := p.(pref.ProtoMessage); ok { + return m.ProtoReflect() + } + return legacyWrapMessage(reflect.ValueOf(p)).ProtoReflect() +} + // MessageType provides protobuf related functionality for a given Go type // that represents a message. A given instance of MessageType is tied to // exactly one Go type, which must be a pointer to a struct type. type MessageType struct { - // Desc is an optionally provided message descriptor. If nil, the descriptor - // is lazily derived from the Go type information of generated messages - // for the v1 API. - // + // Type is the underlying message type and must be populated. // Once set, this field must never be mutated. - Desc pref.MessageDescriptor + Type pref.MessageType once sync.Once // protects all unexported fields - goType reflect.Type // pointer to struct - pbType pref.MessageType // only valid if goType does not implement proto.Message + goType reflect.Type // pointer to struct // TODO: Split fields into dense and sparse maps similar to the current // table-driven implementation in v1? @@ -45,33 +51,12 @@ type MessageType struct { // It must be called at the start of every exported method. func (mi *MessageType) init(p interface{}) { mi.once.Do(func() { - v := reflect.ValueOf(p) - t := v.Type() + t := reflect.TypeOf(p) if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct { panic(fmt.Sprintf("got %v, want *struct kind", t)) } mi.goType = t - // Derive the message descriptor if unspecified. - if mi.Desc == nil { - mi.Desc = loadMessageDesc(t) - } - - // Initialize the Go message type wrapper if the Go type does not - // implement the proto.Message interface. - // - // Otherwise, we assume that the Go type manually implements the - // interface and is internally consistent such that: - // goType == reflect.New(goType.Elem()).Interface().(proto.Message).ProtoReflect().Type().GoType() - // - // Generated code ensures that this property holds. - if _, ok := p.(pref.ProtoMessage); !ok { - mi.pbType = ptype.GoMessage(mi.Desc, func(pref.MessageType) pref.ProtoMessage { - p := reflect.New(t.Elem()).Interface() - return (*legacyMessageWrapper)(mi.dataTypeOf(p)) - }) - } - mi.makeKnownFieldsFunc(t.Elem()) mi.makeUnknownFieldsFunc(t.Elem()) mi.makeExtensionFieldsFunc(t.Elem()) @@ -86,10 +71,9 @@ func (mi *MessageType) init(p interface{}) { } } -// makeKnownFieldsFunc generates per-field functions for all operations -// to be performed on each field. It takes in a reflect.Type representing the -// Go struct, and a protoreflect.MessageDescriptor to match with the fields -// in the struct. +// makeKnownFieldsFunc generates functions for operations that can be performed +// on each protobuf message field. It takes in a reflect.Type representing the +// Go struct and matches message fields with struct fields. // // This code assumes that the struct is well-formed and panics if there are // any discrepancies. @@ -136,8 +120,8 @@ fieldLoop: } mi.fields = map[pref.FieldNumber]*fieldInfo{} - for i := 0; i < mi.Desc.Fields().Len(); i++ { - fd := mi.Desc.Fields().Get(i) + for i := 0; i < mi.Type.Fields().Len(); i++ { + fd := mi.Type.Fields().Get(i) fs := fields[fd.Number()] var fi fieldInfo switch { @@ -169,33 +153,25 @@ func (mi *MessageType) makeUnknownFieldsFunc(t reflect.Type) { } func (mi *MessageType) makeExtensionFieldsFunc(t reflect.Type) { - // TODO + if f := makeLegacyExtensionFieldsFunc(t); f != nil { + mi.extensionFields = f + return + } mi.extensionFields = func(*messageDataType) pref.KnownFields { return emptyExtensionFields{} } } -func (mi *MessageType) MessageOf(p interface{}) pref.Message { - mi.init(p) - if m, ok := p.(pref.ProtoMessage); ok { - // We assume p properly implements protoreflect.Message. - // See the comment in MessageType.init regarding pbType. - return m.ProtoReflect() - } - return (*legacyMessageWrapper)(mi.dataTypeOf(p)) -} - func (mi *MessageType) KnownFieldsOf(p interface{}) pref.KnownFields { - mi.init(p) return (*knownFields)(mi.dataTypeOf(p)) } func (mi *MessageType) UnknownFieldsOf(p interface{}) pref.UnknownFields { - mi.init(p) return mi.unknownFields(mi.dataTypeOf(p)) } func (mi *MessageType) dataTypeOf(p interface{}) *messageDataType { + mi.init(p) return &messageDataType{pointerOfIface(&p), mi} } @@ -249,20 +225,30 @@ func (fs *knownFields) Set(n pref.FieldNumber, v pref.Value) { fi.set(fs.p, v) return } - fs.extensionFields().Set(n, v) + if fs.mi.Type.ExtensionRanges().Has(n) { + fs.extensionFields().Set(n, v) + return + } + panic(fmt.Sprintf("invalid field: %d", n)) } func (fs *knownFields) Clear(n pref.FieldNumber) { if fi := fs.mi.fields[n]; fi != nil { fi.clear(fs.p) return } - fs.extensionFields().Clear(n) + if fs.mi.Type.ExtensionRanges().Has(n) { + fs.extensionFields().Clear(n) + return + } } func (fs *knownFields) Mutable(n pref.FieldNumber) pref.Mutable { if fi := fs.mi.fields[n]; fi != nil { return fi.mutable(fs.p) } - return fs.extensionFields().Mutable(n) + if fs.mi.Type.ExtensionRanges().Has(n) { + return fs.extensionFields().Mutable(n) + } + panic(fmt.Sprintf("invalid field: %d", n)) } func (fs *knownFields) Range(f func(pref.FieldNumber, pref.Value) bool) { for n, fi := range fs.mi.fields { @@ -291,14 +277,14 @@ func (emptyUnknownFields) IsSupported() bool { r type emptyExtensionFields struct{} -func (emptyExtensionFields) Len() int { return 0 } -func (emptyExtensionFields) Has(pref.FieldNumber) bool { return false } -func (emptyExtensionFields) Get(pref.FieldNumber) pref.Value { return pref.Value{} } -func (emptyExtensionFields) Set(pref.FieldNumber, pref.Value) { panic("extensions not supported") } -func (emptyExtensionFields) Clear(pref.FieldNumber) { return } // noop -func (emptyExtensionFields) Mutable(pref.FieldNumber) pref.Mutable { panic("extensions not supported") } -func (emptyExtensionFields) Range(f func(pref.FieldNumber, pref.Value) bool) { return } -func (emptyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes { return emptyExtensionTypes{} } +func (emptyExtensionFields) Len() int { return 0 } +func (emptyExtensionFields) Has(pref.FieldNumber) bool { return false } +func (emptyExtensionFields) Get(pref.FieldNumber) pref.Value { return pref.Value{} } +func (emptyExtensionFields) Set(pref.FieldNumber, pref.Value) { panic("extensions not supported") } +func (emptyExtensionFields) Clear(pref.FieldNumber) { return } // noop +func (emptyExtensionFields) Mutable(pref.FieldNumber) pref.Mutable { panic("extensions not supported") } +func (emptyExtensionFields) Range(func(pref.FieldNumber, pref.Value) bool) { return } +func (emptyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes { return emptyExtensionTypes{} } type emptyExtensionTypes struct{} diff --git a/internal/impl/message_field.go b/internal/impl/message_field.go index 80b369c5..7f8d49f9 100644 --- a/internal/impl/message_field.go +++ b/internal/impl/message_field.go @@ -9,7 +9,7 @@ import ( "reflect" "github.com/golang/protobuf/v2/internal/flags" - "github.com/golang/protobuf/v2/internal/value" + pvalue "github.com/golang/protobuf/v2/internal/value" pref "github.com/golang/protobuf/v2/reflect/protoreflect" ) @@ -42,7 +42,7 @@ func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, ot refle if !reflect.PtrTo(ot).Implements(ft) { panic(fmt.Sprintf("invalid type: %v does not implement %v", ot, ft)) } - conv := value.NewLegacyConverter(ot.Field(0).Type, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage) + conv := newConverter(ot.Field(0).Type, fd.Kind()) fieldOffset := offsetOf(fs) // TODO: Implement unsafe fast path? return fieldInfo{ @@ -106,8 +106,8 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo if ft.Kind() != reflect.Map { panic(fmt.Sprintf("invalid type: got %v, want map kind", ft)) } - keyConv := value.NewLegacyConverter(ft.Key(), fd.MessageType().Fields().ByNumber(1).Kind(), wrapLegacyEnum, wrapLegacyMessage) - valConv := value.NewLegacyConverter(ft.Elem(), fd.MessageType().Fields().ByNumber(2).Kind(), wrapLegacyEnum, wrapLegacyMessage) + keyConv := newConverter(ft.Key(), fd.MessageType().Fields().ByNumber(1).Kind()) + valConv := newConverter(ft.Elem(), fd.MessageType().Fields().ByNumber(2).Kind()) fieldOffset := offsetOf(fs) // TODO: Implement unsafe fast path? return fieldInfo{ @@ -117,11 +117,11 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo }, get: func(p pointer) pref.Value { v := p.apply(fieldOffset).asType(fs.Type).Interface() - return pref.ValueOf(value.MapOf(v, keyConv, valConv)) + return pref.ValueOf(pvalue.MapOf(v, keyConv, valConv)) }, set: func(p pointer, v pref.Value) { rv := p.apply(fieldOffset).asType(fs.Type).Elem() - rv.Set(reflect.ValueOf(v.Map().(value.Unwrapper).Unwrap())) + rv.Set(reflect.ValueOf(v.Map().(pvalue.Unwrapper).Unwrap()).Elem()) }, clear: func(p pointer) { rv := p.apply(fieldOffset).asType(fs.Type).Elem() @@ -129,7 +129,7 @@ func fieldInfoForMap(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo }, mutable: func(p pointer) pref.Mutable { v := p.apply(fieldOffset).asType(fs.Type).Interface() - return value.MapOf(v, keyConv, valConv) + return pvalue.MapOf(v, keyConv, valConv) }, } } @@ -139,7 +139,7 @@ func fieldInfoForVector(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn if ft.Kind() != reflect.Slice { panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft)) } - conv := value.NewLegacyConverter(ft.Elem(), fd.Kind(), wrapLegacyEnum, wrapLegacyMessage) + conv := newConverter(ft.Elem(), fd.Kind()) fieldOffset := offsetOf(fs) // TODO: Implement unsafe fast path? return fieldInfo{ @@ -149,11 +149,11 @@ func fieldInfoForVector(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn }, get: func(p pointer) pref.Value { v := p.apply(fieldOffset).asType(fs.Type).Interface() - return pref.ValueOf(value.VectorOf(v, conv)) + return pref.ValueOf(pvalue.VectorOf(v, conv)) }, set: func(p pointer, v pref.Value) { rv := p.apply(fieldOffset).asType(fs.Type).Elem() - rv.Set(reflect.ValueOf(v.Vector().(value.Unwrapper).Unwrap())) + rv.Set(reflect.ValueOf(v.Vector().(pvalue.Unwrapper).Unwrap()).Elem()) }, clear: func(p pointer) { rv := p.apply(fieldOffset).asType(fs.Type).Elem() @@ -161,7 +161,7 @@ func fieldInfoForVector(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn }, mutable: func(p pointer) pref.Mutable { v := p.apply(fieldOffset).asType(fs.Type).Interface() - return value.VectorOf(v, conv) + return pvalue.VectorOf(v, conv) }, } } @@ -179,7 +179,7 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn ft = ft.Elem() } } - conv := value.NewLegacyConverter(ft, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage) + conv := newConverter(ft, fd.Kind()) fieldOffset := offsetOf(fs) // TODO: Implement unsafe fast path? return fieldInfo{ @@ -244,7 +244,7 @@ func fieldInfoForScalar(fd pref.FieldDescriptor, fs reflect.StructField) fieldIn func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo { ft := fs.Type - conv := value.NewLegacyConverter(ft, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage) + conv := newConverter(ft, fd.Kind()) fieldOffset := offsetOf(fs) // TODO: Implement unsafe fast path? return fieldInfo{ @@ -282,3 +282,12 @@ func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldI }, } } + +// newConverter calls value.NewLegacyConverter with the necessary constructor +// functions for legacy enum and message support. +func newConverter(t reflect.Type, k pref.Kind) pvalue.Converter { + messageType := func(t reflect.Type) pref.MessageType { + return legacyLoadMessageType(t).Type + } + return pvalue.NewLegacyConverter(t, k, legacyLoadEnumType, messageType, legacyWrapMessage) +} diff --git a/internal/impl/message_test.go b/internal/impl/message_test.go index 4041f021..3f386710 100644 --- a/internal/impl/message_test.go +++ b/internal/impl/message_test.go @@ -13,7 +13,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/golang/protobuf/proto" protoV1 "github.com/golang/protobuf/proto" descriptorV1 "github.com/golang/protobuf/protoc-gen-go/descriptor" pref "github.com/golang/protobuf/v2/reflect/protoreflect" @@ -46,29 +45,8 @@ type ( MapStrings map[MyString]MyString MapBytes map[MyString]MyBytes - - MyEnumV1 pref.EnumNumber - MyEnumV2 string - myEnumV2 MyEnumV2 - - MyMessageV1 struct { - // SubMessage *Message - } - MyMessageV2 map[pref.FieldNumber]pref.Value - myMessageV2 MyMessageV2 ) -func (e MyEnumV2) ProtoReflect() pref.Enum { return myEnumV2(e) } -func (e myEnumV2) Type() pref.EnumType { return nil } // TODO -func (e myEnumV2) Number() pref.EnumNumber { return 0 } // TODO - -func (m *MyMessageV2) ProtoReflect() pref.Message { return (*myMessageV2)(m) } -func (m *myMessageV2) Type() pref.MessageType { return nil } // TODO -func (m *myMessageV2) KnownFields() pref.KnownFields { return nil } // TODO -func (m *myMessageV2) UnknownFields() pref.UnknownFields { return nil } // TODO -func (m *myMessageV2) Interface() pref.ProtoMessage { return (*MyMessageV2)(m) } -func (m *myMessageV2) ProtoMutable() {} - // List of test operations to perform on messages, vectors, or maps. type ( messageOp interface{} // equalMessage | hasFields | getFields | setFields | clearFields | vectorFields | mapFields @@ -143,8 +121,8 @@ type ScalarProto2 struct { MyBytesA *MyString `protobuf:"22"` } -func TestScalarProto2(t *testing.T) { - mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{ +var scalarProto2Type = MessageType{Type: ptype.GoMessage( + mustMakeMessageDesc(ptype.StandaloneMessage{ Syntax: pref.Proto2, FullName: "ScalarProto2", Fields: []ptype.Field{ @@ -172,9 +150,21 @@ func TestScalarProto2(t *testing.T) { {Name: "f21", Number: 21, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("21"))}, {Name: "f22", Number: 22, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("22"))}, }, - })} + }), + func(pref.MessageType) pref.ProtoMessage { + return new(ScalarProto2) + }, +)} - testMessage(t, nil, mi.MessageOf(&ScalarProto2{}), messageOps{ +func (m *ScalarProto2) Type() pref.MessageType { return scalarProto2Type.Type } +func (m *ScalarProto2) KnownFields() pref.KnownFields { return scalarProto2Type.KnownFieldsOf(m) } +func (m *ScalarProto2) UnknownFields() pref.UnknownFields { return scalarProto2Type.UnknownFieldsOf(m) } +func (m *ScalarProto2) Interface() pref.ProtoMessage { return m } +func (m *ScalarProto2) ProtoReflect() pref.Message { return m } +func (m *ScalarProto2) ProtoMutable() {} + +func TestScalarProto2(t *testing.T) { + testMessage(t, nil, &ScalarProto2{}, messageOps{ hasFields{ 1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, @@ -191,15 +181,15 @@ func TestScalarProto2(t *testing.T) { 1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, }, - equalMessage(mi.MessageOf(&ScalarProto2{ + equalMessage(&ScalarProto2{ new(bool), new(int32), new(int64), new(uint32), new(uint64), new(float32), new(float64), new(string), []byte{}, []byte{}, new(string), new(MyBool), new(MyInt32), new(MyInt64), new(MyUint32), new(MyUint64), new(MyFloat32), new(MyFloat64), new(MyString), MyBytes{}, MyBytes{}, new(MyString), - })), + }), clearFields{ 1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, }, - equalMessage(mi.MessageOf(&ScalarProto2{})), + equalMessage(&ScalarProto2{}), }) } @@ -229,8 +219,8 @@ type ScalarProto3 struct { MyBytesA MyString `protobuf:"22"` } -func TestScalarProto3(t *testing.T) { - mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{ +var scalarProto3Type = MessageType{Type: ptype.GoMessage( + mustMakeMessageDesc(ptype.StandaloneMessage{ Syntax: pref.Proto3, FullName: "ScalarProto3", Fields: []ptype.Field{ @@ -258,9 +248,21 @@ func TestScalarProto3(t *testing.T) { {Name: "f21", Number: 21, Cardinality: pref.Optional, Kind: pref.BytesKind}, {Name: "f22", Number: 22, Cardinality: pref.Optional, Kind: pref.BytesKind}, }, - })} + }), + func(pref.MessageType) pref.ProtoMessage { + return new(ScalarProto3) + }, +)} - testMessage(t, nil, mi.MessageOf(&ScalarProto3{}), messageOps{ +func (m *ScalarProto3) Type() pref.MessageType { return scalarProto3Type.Type } +func (m *ScalarProto3) KnownFields() pref.KnownFields { return scalarProto3Type.KnownFieldsOf(m) } +func (m *ScalarProto3) UnknownFields() pref.UnknownFields { return scalarProto3Type.UnknownFieldsOf(m) } +func (m *ScalarProto3) Interface() pref.ProtoMessage { return m } +func (m *ScalarProto3) ProtoReflect() pref.Message { return m } +func (m *ScalarProto3) ProtoMutable() {} + +func TestScalarProto3(t *testing.T) { + testMessage(t, nil, &ScalarProto3{}, messageOps{ hasFields{ 1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, @@ -277,7 +279,7 @@ func TestScalarProto3(t *testing.T) { 1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, }, - equalMessage(mi.MessageOf(&ScalarProto3{})), + equalMessage(&ScalarProto3{}), setFields{ 1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V([]byte("10")), 11: V([]byte("11")), 12: V(bool(true)), 13: V(int32(13)), 14: V(int64(14)), 15: V(uint32(15)), 16: V(uint64(16)), 17: V(float32(17)), 18: V(float64(18)), 19: V(string("19")), 20: V(string("20")), 21: V([]byte("21")), 22: V([]byte("22")), @@ -286,19 +288,19 @@ func TestScalarProto3(t *testing.T) { 1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, }, - equalMessage(mi.MessageOf(&ScalarProto3{ + equalMessage(&ScalarProto3{ true, 2, 3, 4, 5, 6, 7, "8", []byte("9"), []byte("10"), "11", true, 13, 14, 15, 16, 17, 18, "19", []byte("20"), []byte("21"), "22", - })), + }), clearFields{ 1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, }, - equalMessage(mi.MessageOf(&ScalarProto3{})), + equalMessage(&ScalarProto3{}), }) } -type RepeatedScalars struct { +type ListScalars struct { Bools []bool `protobuf:"1"` Int32s []int32 `protobuf:"2"` Int64s []int64 `protobuf:"3"` @@ -322,10 +324,10 @@ type RepeatedScalars struct { MyBytes4 VectorStrings `protobuf:"19"` } -func TestRepeatedScalars(t *testing.T) { - mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{ +var listScalarsType = MessageType{Type: ptype.GoMessage( + mustMakeMessageDesc(ptype.StandaloneMessage{ Syntax: pref.Proto2, - FullName: "RepeatedScalars", + FullName: "ListScalars", Fields: []ptype.Field{ {Name: "f1", Number: 1, Cardinality: pref.Repeated, Kind: pref.BoolKind}, {Name: "f2", Number: 2, Cardinality: pref.Repeated, Kind: pref.Int32Kind}, @@ -349,12 +351,24 @@ func TestRepeatedScalars(t *testing.T) { {Name: "f18", Number: 18, Cardinality: pref.Repeated, Kind: pref.BytesKind}, {Name: "f19", Number: 19, Cardinality: pref.Repeated, Kind: pref.BytesKind}, }, - })} + }), + func(pref.MessageType) pref.ProtoMessage { + return new(ListScalars) + }, +)} - empty := mi.MessageOf(&RepeatedScalars{}) +func (m *ListScalars) Type() pref.MessageType { return listScalarsType.Type } +func (m *ListScalars) KnownFields() pref.KnownFields { return listScalarsType.KnownFieldsOf(m) } +func (m *ListScalars) UnknownFields() pref.UnknownFields { return listScalarsType.UnknownFieldsOf(m) } +func (m *ListScalars) Interface() pref.ProtoMessage { return m } +func (m *ListScalars) ProtoReflect() pref.Message { return m } +func (m *ListScalars) ProtoMutable() {} + +func TestListScalars(t *testing.T) { + empty := &ListScalars{} emptyFS := empty.KnownFields() - want := mi.MessageOf(&RepeatedScalars{ + want := &ListScalars{ Bools: []bool{true, false, true}, Int32s: []int32{2, math.MinInt32, math.MaxInt32}, Int64s: []int64{3, math.MinInt64, math.MaxInt64}, @@ -376,10 +390,10 @@ func TestRepeatedScalars(t *testing.T) { MyStrings4: VectorBytes{[]byte("17"), nil, []byte("seventeen")}, MyBytes3: VectorBytes{[]byte("18"), nil, []byte("eighteen")}, MyBytes4: VectorStrings{"19", "", "nineteen"}, - }) + } wantFS := want.KnownFields() - testMessage(t, nil, mi.MessageOf(&RepeatedScalars{}), messageOps{ + testMessage(t, nil, &ListScalars{}, messageOps{ hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false}, getFields{1: emptyFS.Get(1), 3: emptyFS.Get(3), 5: emptyFS.Get(5), 7: emptyFS.Get(7), 9: emptyFS.Get(9), 11: emptyFS.Get(11), 13: emptyFS.Get(13), 15: emptyFS.Get(15), 17: emptyFS.Get(17), 19: emptyFS.Get(19)}, setFields{1: wantFS.Get(1), 3: wantFS.Get(3), 5: wantFS.Get(5), 7: wantFS.Get(7), 9: wantFS.Get(9), 11: wantFS.Get(11), 13: wantFS.Get(13), 15: wantFS.Get(15), 17: wantFS.Get(17), 19: wantFS.Get(19)}, @@ -470,25 +484,26 @@ type MapScalars struct { MyBytes4 MapStrings `protobuf:"25"` } -func TestMapScalars(t *testing.T) { - mustMakeMapEntry := func(n pref.FieldNumber, keyKind, valKind pref.Kind) ptype.Field { - return ptype.Field{ - Name: pref.Name(fmt.Sprintf("f%d", n)), - Number: n, - Cardinality: pref.Repeated, - Kind: pref.MessageKind, - MessageType: mustMakeMessageDesc(ptype.StandaloneMessage{ - Syntax: pref.Proto2, - FullName: pref.FullName(fmt.Sprintf("MapScalars.F%dEntry", n)), - Fields: []ptype.Field{ - {Name: "key", Number: 1, Cardinality: pref.Optional, Kind: keyKind}, - {Name: "value", Number: 2, Cardinality: pref.Optional, Kind: valKind}, - }, - Options: &descriptorV1.MessageOptions{MapEntry: protoV1.Bool(true)}, - }), - } +func mustMakeMapEntry(n pref.FieldNumber, keyKind, valKind pref.Kind) ptype.Field { + return ptype.Field{ + Name: pref.Name(fmt.Sprintf("f%d", n)), + Number: n, + Cardinality: pref.Repeated, + Kind: pref.MessageKind, + MessageType: mustMakeMessageDesc(ptype.StandaloneMessage{ + Syntax: pref.Proto2, + FullName: pref.FullName(fmt.Sprintf("MapScalars.F%dEntry", n)), + Fields: []ptype.Field{ + {Name: "key", Number: 1, Cardinality: pref.Optional, Kind: keyKind}, + {Name: "value", Number: 2, Cardinality: pref.Optional, Kind: valKind}, + }, + Options: &descriptorV1.MessageOptions{MapEntry: protoV1.Bool(true)}, + }), } - mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{ +} + +var mapScalarsType = MessageType{Type: ptype.GoMessage( + mustMakeMessageDesc(ptype.StandaloneMessage{ Syntax: pref.Proto2, FullName: "MapScalars", Fields: []ptype.Field{ @@ -521,12 +536,24 @@ func TestMapScalars(t *testing.T) { mustMakeMapEntry(24, pref.StringKind, pref.BytesKind), mustMakeMapEntry(25, pref.StringKind, pref.BytesKind), }, - })} + }), + func(pref.MessageType) pref.ProtoMessage { + return new(MapScalars) + }, +)} - empty := mi.MessageOf(&MapScalars{}) +func (m *MapScalars) Type() pref.MessageType { return mapScalarsType.Type } +func (m *MapScalars) KnownFields() pref.KnownFields { return mapScalarsType.KnownFieldsOf(m) } +func (m *MapScalars) UnknownFields() pref.UnknownFields { return mapScalarsType.UnknownFieldsOf(m) } +func (m *MapScalars) Interface() pref.ProtoMessage { return m } +func (m *MapScalars) ProtoReflect() pref.Message { return m } +func (m *MapScalars) ProtoMutable() {} + +func TestMapScalars(t *testing.T) { + empty := &MapScalars{} emptyFS := empty.KnownFields() - want := mi.MessageOf(&MapScalars{ + want := &MapScalars{ KeyBools: map[bool]string{true: "true", false: "false"}, KeyInt32s: map[int32]string{0: "zero", -1: "one", 2: "two"}, KeyInt64s: map[int64]string{0: "zero", -10: "ten", 20: "twenty"}, @@ -555,10 +582,10 @@ func TestMapScalars(t *testing.T) { MyStrings4: MapBytes{"s1": []byte("s1"), "s2": []byte("s2")}, MyBytes3: MapBytes{"s1": []byte("s1"), "s2": []byte("s2")}, MyBytes4: MapStrings{"s1": "s1", "s2": "s2"}, - }) + } wantFS := want.KnownFields() - testMessage(t, nil, mi.MessageOf(&MapScalars{}), messageOps{ + testMessage(t, nil, &MapScalars{}, messageOps{ hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, 23: false, 24: false, 25: false}, getFields{1: emptyFS.Get(1), 3: emptyFS.Get(3), 5: emptyFS.Get(5), 7: emptyFS.Get(7), 9: emptyFS.Get(9), 11: emptyFS.Get(11), 13: emptyFS.Get(13), 15: emptyFS.Get(15), 17: emptyFS.Get(17), 19: emptyFS.Get(19), 21: emptyFS.Get(21), 23: emptyFS.Get(23), 25: emptyFS.Get(25)}, setFields{1: wantFS.Get(1), 3: wantFS.Get(3), 5: wantFS.Get(5), 7: wantFS.Get(7), 9: wantFS.Get(9), 11: wantFS.Get(11), 13: wantFS.Get(13), 15: wantFS.Get(15), 17: wantFS.Get(17), 19: wantFS.Get(19), 21: wantFS.Get(21), 23: wantFS.Get(23), 25: wantFS.Get(25)}, @@ -687,7 +714,40 @@ type ( } ) -func (*OneofScalars) XXX_OneofFuncs() (func(proto.Message, *proto.Buffer) error, func(proto.Message, int, int, *proto.Buffer) (bool, error), func(proto.Message) int, []interface{}) { +var oneofScalarsType = MessageType{Type: ptype.GoMessage( + mustMakeMessageDesc(ptype.StandaloneMessage{ + Syntax: pref.Proto2, + FullName: "ScalarProto2", + Fields: []ptype.Field{ + {Name: "f1", Number: 1, Cardinality: pref.Optional, Kind: pref.BoolKind, Default: V(bool(true)), OneofName: "union"}, + {Name: "f2", Number: 2, Cardinality: pref.Optional, Kind: pref.Int32Kind, Default: V(int32(2)), OneofName: "union"}, + {Name: "f3", Number: 3, Cardinality: pref.Optional, Kind: pref.Int64Kind, Default: V(int64(3)), OneofName: "union"}, + {Name: "f4", Number: 4, Cardinality: pref.Optional, Kind: pref.Uint32Kind, Default: V(uint32(4)), OneofName: "union"}, + {Name: "f5", Number: 5, Cardinality: pref.Optional, Kind: pref.Uint64Kind, Default: V(uint64(5)), OneofName: "union"}, + {Name: "f6", Number: 6, Cardinality: pref.Optional, Kind: pref.FloatKind, Default: V(float32(6)), OneofName: "union"}, + {Name: "f7", Number: 7, Cardinality: pref.Optional, Kind: pref.DoubleKind, Default: V(float64(7)), OneofName: "union"}, + {Name: "f8", Number: 8, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("8")), OneofName: "union"}, + {Name: "f9", Number: 9, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("9")), OneofName: "union"}, + {Name: "f10", Number: 10, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("10")), OneofName: "union"}, + {Name: "f11", Number: 11, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("11")), OneofName: "union"}, + {Name: "f12", Number: 12, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("12")), OneofName: "union"}, + {Name: "f13", Number: 13, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("13")), OneofName: "union"}, + }, + Oneofs: []ptype.Oneof{{Name: "union"}}, + }), + func(pref.MessageType) pref.ProtoMessage { + return new(OneofScalars) + }, +)} + +func (m *OneofScalars) Type() pref.MessageType { return oneofScalarsType.Type } +func (m *OneofScalars) KnownFields() pref.KnownFields { return oneofScalarsType.KnownFieldsOf(m) } +func (m *OneofScalars) UnknownFields() pref.UnknownFields { return oneofScalarsType.UnknownFieldsOf(m) } +func (m *OneofScalars) Interface() pref.ProtoMessage { return m } +func (m *OneofScalars) ProtoReflect() pref.Message { return m } +func (m *OneofScalars) ProtoMutable() {} + +func (*OneofScalars) XXX_OneofFuncs() (func(protoV1.Message, *protoV1.Buffer) error, func(protoV1.Message, int, int, *protoV1.Buffer) (bool, error), func(protoV1.Message) int, []interface{}) { return nil, nil, nil, []interface{}{ (*OneofScalars_Bool)(nil), (*OneofScalars_Int32)(nil), @@ -720,43 +780,22 @@ func (*OneofScalars_BytesA) isOneofScalars_Union() {} func (*OneofScalars_BytesB) isOneofScalars_Union() {} func TestOneofs(t *testing.T) { - mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{ - Syntax: pref.Proto2, - FullName: "ScalarProto2", - Fields: []ptype.Field{ - {Name: "f1", Number: 1, Cardinality: pref.Optional, Kind: pref.BoolKind, Default: V(bool(true)), OneofName: "union"}, - {Name: "f2", Number: 2, Cardinality: pref.Optional, Kind: pref.Int32Kind, Default: V(int32(2)), OneofName: "union"}, - {Name: "f3", Number: 3, Cardinality: pref.Optional, Kind: pref.Int64Kind, Default: V(int64(3)), OneofName: "union"}, - {Name: "f4", Number: 4, Cardinality: pref.Optional, Kind: pref.Uint32Kind, Default: V(uint32(4)), OneofName: "union"}, - {Name: "f5", Number: 5, Cardinality: pref.Optional, Kind: pref.Uint64Kind, Default: V(uint64(5)), OneofName: "union"}, - {Name: "f6", Number: 6, Cardinality: pref.Optional, Kind: pref.FloatKind, Default: V(float32(6)), OneofName: "union"}, - {Name: "f7", Number: 7, Cardinality: pref.Optional, Kind: pref.DoubleKind, Default: V(float64(7)), OneofName: "union"}, - {Name: "f8", Number: 8, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("8")), OneofName: "union"}, - {Name: "f9", Number: 9, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("9")), OneofName: "union"}, - {Name: "f10", Number: 10, Cardinality: pref.Optional, Kind: pref.StringKind, Default: V(string("10")), OneofName: "union"}, - {Name: "f11", Number: 11, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("11")), OneofName: "union"}, - {Name: "f12", Number: 12, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("12")), OneofName: "union"}, - {Name: "f13", Number: 13, Cardinality: pref.Optional, Kind: pref.BytesKind, Default: V([]byte("13")), OneofName: "union"}, - }, - Oneofs: []ptype.Oneof{{Name: "union"}}, - })} + empty := &OneofScalars{} + want1 := &OneofScalars{Union: &OneofScalars_Bool{true}} + want2 := &OneofScalars{Union: &OneofScalars_Int32{20}} + want3 := &OneofScalars{Union: &OneofScalars_Int64{30}} + want4 := &OneofScalars{Union: &OneofScalars_Uint32{40}} + want5 := &OneofScalars{Union: &OneofScalars_Uint64{50}} + want6 := &OneofScalars{Union: &OneofScalars_Float32{60}} + want7 := &OneofScalars{Union: &OneofScalars_Float64{70}} + want8 := &OneofScalars{Union: &OneofScalars_String{string("80")}} + want9 := &OneofScalars{Union: &OneofScalars_StringA{[]byte("90")}} + want10 := &OneofScalars{Union: &OneofScalars_StringB{MyString("100")}} + want11 := &OneofScalars{Union: &OneofScalars_Bytes{[]byte("110")}} + want12 := &OneofScalars{Union: &OneofScalars_BytesA{string("120")}} + want13 := &OneofScalars{Union: &OneofScalars_BytesB{MyBytes("130")}} - empty := mi.MessageOf(&OneofScalars{}) - want1 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Bool{true}}) - want2 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Int32{20}}) - want3 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Int64{30}}) - want4 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Uint32{40}}) - want5 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Uint64{50}}) - want6 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Float32{60}}) - want7 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Float64{70}}) - want8 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_String{string("80")}}) - want9 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_StringA{[]byte("90")}}) - want10 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_StringB{MyString("100")}}) - want11 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_Bytes{[]byte("110")}}) - want12 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_BytesA{string("120")}}) - want13 := mi.MessageOf(&OneofScalars{Union: &OneofScalars_BytesB{MyBytes("130")}}) - - testMessage(t, nil, mi.MessageOf(&OneofScalars{}), messageOps{ + testMessage(t, nil, &OneofScalars{}, messageOps{ hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false}, getFields{1: V(bool(true)), 2: V(int32(2)), 3: V(int64(3)), 4: V(uint32(4)), 5: V(uint64(5)), 6: V(float32(6)), 7: V(float64(7)), 8: V(string("8")), 9: V(string("9")), 10: V(string("10")), 11: V([]byte("11")), 12: V([]byte("12")), 13: V([]byte("13"))}, @@ -789,13 +828,6 @@ var cmpOpts = cmp.Options{ cmp.Transformer("UnwrapValue", func(v pref.Value) interface{} { return v.Interface() }), - cmp.Transformer("UnwrapMessage", func(m pref.Message) interface{} { - v := m.Interface() - if v, ok := v.(interface{ Unwrap() interface{} }); ok { - return v.Unwrap() - } - return v - }), cmp.Transformer("UnwrapVector", func(v pref.Vector) interface{} { return v.(interface{ Unwrap() interface{} }).Unwrap() }), diff --git a/internal/value/convert.go b/internal/value/convert.go index e4692a7f..ea29244a 100644 --- a/internal/value/convert.go +++ b/internal/value/convert.go @@ -53,17 +53,23 @@ var ( // protoc-gen-go historically generated to be able to automatically wrap some // v1 messages generated by other forks of protoc-gen-go. func NewConverter(t reflect.Type, k pref.Kind) Converter { - return NewLegacyConverter(t, k, nil, nil) + return NewLegacyConverter(t, k, nil, nil, nil) } +// Legacy enums and messages do not self-report their own protoreflect types. +// Thus, the caller needs to provide functions for retrieving those when +// a v1 enum or message is encountered. +type ( + enumTypeOf = func(reflect.Type) pref.EnumType + messageTypeOf = func(reflect.Type) pref.MessageType + messageValueOf = func(reflect.Value) pref.ProtoMessage +) + // NewLegacyConverter is identical to NewConverter, // but supports wrapping legacy v1 messages to implement the v2 message API -// using the provided wrapEnum and wrapMessage functions. +// using the provided enumTypeOf, messageTypeOf and messageValueOf functions. // The wrapped message must implement Unwrapper. -func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapEnum func(reflect.Value) pref.ProtoEnum, wrapMessage func(reflect.Value) pref.ProtoMessage) Converter { - if (wrapEnum == nil) != (wrapMessage == nil) { - panic("legacy enum and message wrappers must both be populated or nil") - } +func NewLegacyConverter(t reflect.Type, k pref.Kind, etOf enumTypeOf, mtOf messageTypeOf, mvOf messageValueOf) Converter { switch k { case pref.BoolKind: if t.Kind() == reflect.Bool { @@ -125,8 +131,8 @@ func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapEnum func(reflect.Value } // Handle v1 enums, which we identify as simply a named int32 type. - if wrapEnum != nil && t.PkgPath() != "" && t.Kind() == reflect.Int32 { - et := wrapEnum(reflect.Zero(t)).ProtoReflect().Type() + if etOf != nil && t.PkgPath() != "" && t.Kind() == reflect.Int32 { + et := etOf(t) return Converter{ PBValueOf: func(v reflect.Value) pref.Value { if v.Type() != t { @@ -164,14 +170,14 @@ func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapEnum func(reflect.Value } // Handle v1 messages, which we need to wrap as a v2 message. - if wrapMessage != nil && t.Kind() == reflect.Ptr && t.Implements(messageIfaceV1) { - mt := wrapMessage(reflect.New(t.Elem())).ProtoReflect().Type() + if mtOf != nil && t.Kind() == reflect.Ptr && t.Implements(messageIfaceV1) { + mt := mtOf(t) return Converter{ PBValueOf: func(v reflect.Value) pref.Value { if v.Type() != t { panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t)) } - return pref.ValueOf(wrapMessage(v).ProtoReflect()) + return pref.ValueOf(mvOf(v).ProtoReflect()) }, GoValueOf: func(v pref.Value) reflect.Value { rv := reflect.ValueOf(v.Message().(Unwrapper).Unwrap()) diff --git a/internal/value/map.go b/internal/value/map.go index c0510747..d2ef6592 100644 --- a/internal/value/map.go +++ b/internal/value/map.go @@ -77,7 +77,7 @@ func (ms mapReflect) Range(f func(pref.MapKey, pref.Value) bool) { } } func (ms mapReflect) Unwrap() interface{} { - return ms.v.Interface() + return ms.v.Addr().Interface() } func (ms mapReflect) ProtoMutable() {} diff --git a/internal/value/vector.go b/internal/value/vector.go index 664ece27..d6efd99e 100644 --- a/internal/value/vector.go +++ b/internal/value/vector.go @@ -53,7 +53,7 @@ func (vs vectorReflect) Truncate(i int) { vs.v.Set(vs.v.Slice(0, i)) } func (vs vectorReflect) Unwrap() interface{} { - return vs.v.Interface() + return vs.v.Addr().Interface() } func (vs vectorReflect) ProtoMutable() {} diff --git a/reflect/protoreflect/type.go b/reflect/protoreflect/type.go index 1fcc575c..ceb763c8 100644 --- a/reflect/protoreflect/type.go +++ b/reflect/protoreflect/type.go @@ -446,20 +446,21 @@ type ExtensionType interface { // t.GoType() == reflect.TypeOf(t.InterfaceOf(t.ValueOf(t.New()))) GoType() reflect.Type - // TODO: How do we reconcile GoType with the existing extension API, - // which returns *T for scalars (causing unnecessary aliasing), - // and []T for vectors (causing insufficient aliasing)? + // TODO: What to do with nil? + // Should ValueOf(nil) return Value{}? + // Should InterfaceOf(Value{}) return nil? + // Similarly, should the Value.{Message,Vector,Map} also return nil? // ValueOf wraps the input and returns it as a Value. - // ValueOf panics if the input value is not the appropriate type. + // ValueOf panics if the input value is invalid or not the appropriate type. // // ValueOf is more extensive than protoreflect.ValueOf for a given field's // value as it has more type information available. ValueOf(interface{}) Value // InterfaceOf completely unwraps the Value to the underlying Go type. - // InterfaceOf panics if the input does not represent the appropriate - // underlying Go type. + // InterfaceOf panics if the input is nil or does not represent the + // appropriate underlying Go type. // // InterfaceOf is able to unwrap the Value further than Value.Interface // as it has more type information available. diff --git a/reflect/prototype/go_type.go b/reflect/prototype/go_type.go index 10f1f546..00d440e1 100644 --- a/reflect/prototype/go_type.go +++ b/reflect/prototype/go_type.go @@ -19,22 +19,24 @@ func GoEnum(ed protoreflect.EnumDescriptor, fn func(protoreflect.EnumType, proto if ed.IsPlaceholder() { panic("enum descriptor must not be a placeholder") } - t := &goEnum{EnumDescriptor: ed, new: fn} - t.typ = reflect.TypeOf(fn(t, 0)) - return t + return &goEnum{EnumDescriptor: ed, new: fn} } type goEnum struct { protoreflect.EnumDescriptor - typ reflect.Type new func(protoreflect.EnumType, protoreflect.EnumNumber) protoreflect.ProtoEnum + + once sync.Once + typ reflect.Type } func (t *goEnum) GoType() reflect.Type { + t.New(0) // initialize t.typ return t.typ } func (t *goEnum) New(n protoreflect.EnumNumber) protoreflect.ProtoEnum { e := t.new(t, n) + t.once.Do(func() { t.typ = reflect.TypeOf(e) }) if t.typ != reflect.TypeOf(e) { panic(fmt.Sprintf("mismatching types for enum: got %T, want %v", e, t.typ)) } @@ -47,22 +49,26 @@ func GoMessage(md protoreflect.MessageDescriptor, fn func(protoreflect.MessageTy if md.IsPlaceholder() { panic("message descriptor must not be a placeholder") } - t := &goMessage{MessageDescriptor: md, new: fn} - t.typ = reflect.TypeOf(fn(t)) - return t + // NOTE: Avoid calling fn in the constructor since fn itself may depend on + // this function returning (for cyclic message dependencies). + return &goMessage{MessageDescriptor: md, new: fn} } type goMessage struct { protoreflect.MessageDescriptor - typ reflect.Type new func(protoreflect.MessageType) protoreflect.ProtoMessage + + once sync.Once + typ reflect.Type } func (t *goMessage) GoType() reflect.Type { + t.New() // initialize t.typ return t.typ } func (t *goMessage) New() protoreflect.ProtoMessage { m := t.new(t) + t.once.Do(func() { t.typ = reflect.TypeOf(m) }) if t.typ != reflect.TypeOf(m) { panic(fmt.Sprintf("mismatching types for message: got %T, want %v", m, t.typ)) } @@ -240,11 +246,11 @@ func (t *goExtension) lazyInit() { t.valueOf = func(v interface{}) protoreflect.Value { return protoreflect.ValueOf(value.VectorOf(v, c)) } - t.interfaceOf = func(v protoreflect.Value) interface{} { + t.interfaceOf = func(pv protoreflect.Value) interface{} { // TODO: Can we assume that Vector implementations know how // to unwrap themselves? // Should this be part of the public API in protoreflect? - return v.Vector().(value.Unwrapper).Unwrap() + return pv.Vector().(value.Unwrapper).Unwrap() } default: panic(fmt.Sprintf("invalid cardinality: %v", t.Cardinality()))